349 lines
11 KiB
C++
349 lines
11 KiB
C++
/* ----------------------------------------------------------------------------
|
|
|
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
|
* Atlanta, Georgia 30332-0415
|
|
* All Rights Reserved
|
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
|
|
|
* See LICENSE for the license information
|
|
|
|
* -------------------------------------------------------------------------- */
|
|
|
|
/**
|
|
* @file TableFactor.h
|
|
* @date May 4, 2023
|
|
* @author Yoonwoo Kim, Varun Agrawal
|
|
*/
|
|
|
|
#pragma once
|
|
|
|
#include <gtsam/discrete/DiscreteFactor.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/Ring.h>
#include <gtsam/inference/Ordering.h>

#include <Eigen/Sparse>
#include <algorithm>
#include <cstdint>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
|
|
|
|
namespace gtsam {
|
|
|
|
// Forward declarations: this header only names these types through const
// references (a constructor and the HybridValues error overload), so their
// full definitions are not needed here.
class HybridValues;
class DiscreteConditional;
|
|
|
|
/**
|
|
* A discrete probabilistic factor optimized for sparsity.
|
|
* Uses sparse_table_ to store only the nonzero probabilities.
|
|
* Computes the assigned value for the key using the ordering which the
|
|
* nonzero probabilties are stored in. (lazy cartesian product)
|
|
*
|
|
* @ingroup discrete
|
|
*/
|
|
/**
 * A discrete probabilistic factor optimized for sparsity.
 *
 * Only the nonzero probabilities are stored, in `sparse_table_`. The
 * assignment belonging to a stored entry is not kept explicitly; it is
 * recovered from the entry's index using the ordering in which the nonzero
 * probabilities are stored (a lazy cartesian product, see keyValueForIndex).
 *
 * @ingroup discrete
 */
class GTSAM_EXPORT TableFactor : public DiscreteFactor {
 protected:
  /// SparseVector of nonzero probabilities.
  Eigen::SparseVector<double> sparse_table_;

 private:
  /// Map of Keys and their denominators used in keyValueForIndex.
  std::map<Key, size_t> denominators_;

  /// Sorted DiscreteKeys to use internally.
  DiscreteKeys sorted_dkeys_;

  /**
   * @brief Uses a lazy cartesian product to find the value assigned to
   * `target_key` in the nth entry of the cartesian product of the key
   * domains, in O(1), without materializing the product.
   *
   * Example:
   *   v0 | v1 | val
   *    0 |  0 | 10
   *    0 |  1 | 21
   *    1 |  0 | 32
   *    1 |  1 | 43
   *   keyValueForIndex(v1, 2) = 0
   *
   * @param target_key key whose assigned value in the nth entry is wanted
   * @param index nth entry in the sparse vector
   * @return size_t the value assigned to `target_key` in that entry
   */
  size_t keyValueForIndex(Key target_key, uint64_t index) const;

  /**
   * @brief Return the ith key in keys_ as a DiscreteKey.
   *
   * `i` is not bounds-checked; the cardinality is looked up with
   * `cardinalities_.at()`, so a key missing from cardinalities_ presumably
   * throws instead of being inserted.
   *
   * @param i index into keys_
   * @return DiscreteKey pairing keys_[i] with its cardinality
   */
  DiscreteKey discreteKey(size_t i) const {
    return DiscreteKey(keys_[i], cardinalities_.at(keys_[i]));
  }

  /**
   * Convert a probability table given as doubles to a SparseVector that keeps
   * only the nonzero entries.
   * Example: {0, 1, 1, 0, 0, 1, 0} -> values: {1, 1, 1}, indices: {1, 2, 5}
   */
  static Eigen::SparseVector<double> Convert(const DiscreteKeys& keys,
                                             const std::vector<double>& table);

  /// Convert a probability table given as a string to a SparseVector.
  static Eigen::SparseVector<double> Convert(const DiscreteKeys& keys,
                                             const std::string& table);

 public:
  // typedefs needed to play nice with gtsam
  typedef TableFactor This;
  typedef DiscreteFactor Base;  ///< Typedef to base class
  typedef std::shared_ptr<TableFactor> shared_ptr;
  /// Iterator over the stored (nonzero) entries of sparse_table_.
  typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
  /// List of (assignment, value) pairs.
  typedef std::vector<std::pair<DiscreteValues, double>> AssignValList;

 public:
  /// @name Standard Constructors
  /// @{

  /** Default constructor for I/O */
  TableFactor();

  /** Constructor from DiscreteKeys and TableFactor */
  TableFactor(const DiscreteKeys& keys, const TableFactor& potentials);

  /** Constructor from sparse_table */
  TableFactor(const DiscreteKeys& keys,
              const Eigen::SparseVector<double>& table);

  /** Constructor from a dense table of doubles (sparsified via Convert). */
  TableFactor(const DiscreteKeys& keys, const std::vector<double>& table)
      : TableFactor(keys, Convert(keys, table)) {}

  /** Constructor from a table given as a string (parsed via Convert). */
  TableFactor(const DiscreteKeys& keys, const std::string& table)
      : TableFactor(keys, Convert(keys, table)) {}

  /// Single-key specialization
  template <class SOURCE>
  TableFactor(const DiscreteKey& key, SOURCE table)
      : TableFactor(DiscreteKeys{key}, table) {}

  /// Single-key specialization, with vector of doubles.
  TableFactor(const DiscreteKey& key, const std::vector<double>& row)
      : TableFactor(DiscreteKeys{key}, row) {}

  /// Constructor from DecisionTreeFactor
  TableFactor(const DiscreteKeys& keys, const DecisionTreeFactor& dtf);
  TableFactor(const DecisionTreeFactor& dtf);

  /// Constructor from DecisionTree<Key, double>/AlgebraicDecisionTree
  TableFactor(const DiscreteKeys& keys, const DecisionTree<Key, double>& dtree);

  /** Construct from a DiscreteConditional type */
  explicit TableFactor(const DiscreteConditional& c);

  /// @}
  /// @name Testable
  /// @{

  /// equality
  bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;

  /// print
  void print(
      const std::string& s = "TableFactor:\n",
      const KeyFormatter& formatter = DefaultKeyFormatter) const override;

  /// @}
  /// @name Standard Interface
  /// @{

  /// Evaluate probability distribution, is just look up in TableFactor.
  double evaluate(const Assignment<Key>& values) const override;

  /// Calculate error for DiscreteValues `values`, is -log(probability).
  double error(const DiscreteValues& values) const override;

  /// multiply two TableFactors
  TableFactor operator*(const TableFactor& f) const {
    return apply(f, Ring::mul);
  }

  /// multiply with DecisionTreeFactor
  DecisionTreeFactor operator*(const DecisionTreeFactor& f) const override;

  /**
   * Scalar division `a / b` used by operator/.
   * What makes it "safe" (the handling of a zero divisor) is defined in the
   * implementation file, not in this header.
   */
  static double safe_div(const double& a, const double& b);

  /// divide by factor f (safely), combining entries with safe_div
  TableFactor operator/(const TableFactor& f) const {
    return apply(f, safe_div);
  }

  /// Convert into a DecisionTreeFactor
  DecisionTreeFactor toDecisionTreeFactor() const override;

  /// Create a TableFactor that is a subset of this TableFactor
  TableFactor choose(const DiscreteValues assignments,
                     DiscreteKeys parent_keys) const;

  /// Create new factor by summing all values with the same separator values
  shared_ptr sum(size_t nrFrontals) const {
    return combine(nrFrontals, Ring::add);
  }

  /// Create new factor by summing all values with the same separator values
  shared_ptr sum(const Ordering& keys) const {
    return combine(keys, Ring::add);
  }

  /// Create new factor by maximizing over all values with the same separator.
  shared_ptr max(size_t nrFrontals) const {
    return combine(nrFrontals, Ring::max);
  }

  /// Create new factor by maximizing over all values with the same separator.
  shared_ptr max(const Ordering& keys) const {
    return combine(keys, Ring::max);
  }

  /// @}
  /// @name Advanced Interface
  /// @{

  /**
   * Apply unary operator `op(*this)` where `op` accepts the discrete value.
   * @param op a unary operator that operates on TableFactor
   */
  TableFactor apply(Unary op) const;

  /**
   * Apply unary operator `op(*this)` where `op` accepts the discrete
   * assignment and the value at that assignment.
   * @param op a unary operator that operates on TableFactor
   */
  TableFactor apply(UnaryAssignment op) const;

  /**
   * Apply binary operator (*this) "op" f
   * @param f the second argument for op
   * @param op a binary operator that operates on TableFactor
   */
  TableFactor apply(const TableFactor& f, Binary op) const;

  /**
   * Return keys in contract mode.
   *
   * Modes are each of the dimensions of a sparse tensor,
   * and the contract modes represent which dimensions will
   * be involved in contraction (aka tensor multiplication).
   */
  DiscreteKeys contractDkeys(const TableFactor& f) const;

  /**
   * @brief Return keys in free mode, which are the dimensions
   * not involved in the contraction operation.
   */
  DiscreteKeys freeDkeys(const TableFactor& f) const;

  /// Return union of DiscreteKeys in two factors.
  DiscreteKeys unionDkeys(const TableFactor& f) const;

  /// Create unique representation of union modes.
  uint64_t unionRep(const DiscreteKeys& keys, const DiscreteValues& assign,
                    const uint64_t idx) const;

  /// Create a hash map of input factor with assignment of contract modes as
  /// keys and vector of hashed assignment of free modes and value as values.
  std::unordered_map<uint64_t, AssignValList> createMap(
      const DiscreteKeys& contract, const DiscreteKeys& free) const;

  /// Create unique representation
  uint64_t uniqueRep(const DiscreteKeys& keys, const uint64_t idx) const;

  /// Create unique representation with DiscreteValues
  uint64_t uniqueRep(const DiscreteValues& assignments) const;

  /// Find DiscreteValues for corresponding index.
  DiscreteValues findAssignments(const uint64_t idx) const;

  /// Find value for corresponding DiscreteValues.
  double findValue(const DiscreteValues& values) const;

  /**
   * Combine frontal variables using binary operator "op"
   * @param nrFrontals nr. of frontal to combine variables in this factor
   * @param op a binary operator that operates on TableFactor
   * @return shared pointer to newly created TableFactor
   */
  shared_ptr combine(size_t nrFrontals, Binary op) const;

  /**
   * Combine frontal variables in an Ordering using binary operator "op"
   * @param keys the Ordering of frontal keys to combine in this factor
   * @param op a binary operator that operates on TableFactor
   * @return shared pointer to newly created TableFactor
   */
  shared_ptr combine(const Ordering& keys, Binary op) const;

  /// Enumerate all values as a vector of (assignment, value) pairs.
  std::vector<std::pair<DiscreteValues, double>> enumerate() const;

  /**
   * @brief Prune the factor's discrete assignments.
   *
   * Pruning will set the values to be "pruned" to 0 indicating a 0
   * probability. An assignment is pruned if it is not in the top
   * `maxNrAssignments` values.
   *
   * A violation can occur if there are more duplicate values than
   * `maxNrAssignments`. A violation here is the need to un-prune the factor
   * (e.g. all assignment values are 1.0). We could have another case where
   * some subset of duplicates exist (e.g. for a table with 8 assignments we
   * have 1, 1, 1, 1, 0.8, 0.7, 0.6, 0.5), but this is not a violation since
   * for `maxNrAssignments=5` the top values are (1, 0.8).
   *
   * @param maxNrAssignments The maximum number of assignments to keep.
   * @return TableFactor
   */
  TableFactor prune(size_t maxNrAssignments) const;

  /// @}
  /// @name Wrapper support
  /// @{

  /**
   * @brief Render as markdown table
   *
   * @param keyFormatter GTSAM-style Key formatter.
   * @param names optional, category names corresponding to choices.
   * @return std::string a markdown string.
   */
  std::string markdown(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
                       const Names& names = {}) const override;

  /**
   * @brief Render as html table
   *
   * @param keyFormatter GTSAM-style Key formatter.
   * @param names optional, category names corresponding to choices.
   * @return std::string a html string.
   */
  std::string html(const KeyFormatter& keyFormatter = DefaultKeyFormatter,
                   const Names& names = {}) const override;

  /// @}
  /// @name HybridValues methods.
  /// @{

  /**
   * Calculate error for HybridValues `values`, is -log(probability).
   * Simply dispatches to DiscreteValues version.
   */
  double error(const HybridValues& values) const override;

  /// @}
};
|
|
|
|
// traits
|
|
template <>
|
|
struct traits<TableFactor> : public Testable<TableFactor> {};
|
|
} // namespace gtsam
|