diff --git a/gtsam/discrete/DiscreteTableConditional.cpp b/gtsam/discrete/DiscreteTableConditional.cpp new file mode 100644 index 000000000..b09e2738f --- /dev/null +++ b/gtsam/discrete/DiscreteTableConditional.cpp @@ -0,0 +1,181 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteTableConditional.cpp + * @date Dec 22, 2024 + * @author Varun Agrawal + */ + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +using std::pair; +using std::stringstream; +using std::vector; +namespace gtsam { + +/* ************************************************************************** */ +DiscreteTableConditional::DiscreteTableConditional(const size_t nrFrontals, + const TableFactor& f) + : BaseConditional(nrFrontals, DecisionTreeFactor(f.discreteKeys(), ADT())), + sparse_table_((f / (*f.sum(nrFrontals))).sparseTable()) { + // sparse_table_ = sparse_table_.prune(); +} + +/* ************************************************************************** */ +DiscreteTableConditional::DiscreteTableConditional( + size_t nrFrontals, const DiscreteKeys& keys, + const Eigen::SparseVector& potentials) + : BaseConditional(nrFrontals, keys, DecisionTreeFactor(keys, ADT())), + sparse_table_(potentials) {} + +/* ************************************************************************** */ +DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint, + const TableFactor& marginal) + : BaseConditional(joint.size() - marginal.size(), + joint.discreteKeys() & marginal.discreteKeys(), ADT()), + sparse_table_((joint / marginal).sparseTable()) {} + +/* ************************************************************************** */ +DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint, + const TableFactor& marginal, + const Ordering& orderedKeys) + : DiscreteTableConditional(joint, marginal) { + keys_.clear(); + keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end()); +} + +/* ************************************************************************** */ +DiscreteTableConditional::DiscreteTableConditional(const Signature& signature) + : BaseConditional(1, DecisionTreeFactor()), + sparse_table_(TableFactor(signature.discreteKeys(), signature.cpt()) + .sparseTable()) {} + +/* ************************************************************************** */ +DiscreteTableConditional DiscreteTableConditional::operator*( + const DiscreteTableConditional& other) const { + // Take union of frontal keys + std::set newFrontals; + for (auto&& key : this->frontals()) newFrontals.insert(key); + for (auto&& key : other.frontals()) newFrontals.insert(key); + + // Check if frontals overlapped + if (nrFrontals() + other.nrFrontals() > newFrontals.size()) + throw std::invalid_argument( + "DiscreteTableConditional::operator* called with overlapping frontal " + "keys."); + + // Now, add cardinalities. + DiscreteKeys discreteKeys; + for (auto&& key : frontals()) + discreteKeys.emplace_back(key, cardinality(key)); + for (auto&& key : other.frontals()) + discreteKeys.emplace_back(key, other.cardinality(key)); + + // Sort + std::sort(discreteKeys.begin(), discreteKeys.end()); + + // Add parents to set, to make them unique + std::set parents; + for (auto&& key : this->parents()) + if (!newFrontals.count(key)) parents.emplace(key, cardinality(key)); + for (auto&& key : other.parents()) + if (!newFrontals.count(key)) parents.emplace(key, other.cardinality(key)); + + // Finally, add parents to keys, in order + for (auto&& dk : parents) discreteKeys.push_back(dk); + + TableFactor a(this->discreteKeys(), this->sparse_table_), + b(other.discreteKeys(), other.sparse_table_); + TableFactor product = a * other; + return DiscreteTableConditional(newFrontals.size(), product); +} + +/* ************************************************************************** */ +void DiscreteTableConditional::print(const string& s, + const KeyFormatter& formatter) const { + cout << s << " P( "; + for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) { + cout << formatter(*it) << " "; + } + if (nrParents()) { + cout << "| "; + for (const_iterator it = beginParents(); it != endParents(); ++it) { + cout << formatter(*it) << " "; + } + } + cout << "):\n"; + // BaseFactor::print("", formatter); + cout << endl; +} + +/* ************************************************************************** */ +bool DiscreteTableConditional::equals(const DiscreteFactor& other, + double tol) const { + if (!dynamic_cast(&other)) { + return false; + } else { + const DiscreteConditional& f( + static_cast(other)); + return DiscreteConditional::equals(f, tol); + } +} + +/* ************************************************************************** */ +TableFactor::shared_ptr DiscreteTableConditional::likelihood( + const DiscreteValues& frontalValues) const { + throw std::runtime_error("Likelihood not implemented"); +} + +/* ****************************************************************************/ +TableFactor::shared_ptr DiscreteTableConditional::likelihood( + size_t frontal) const { + throw std::runtime_error("Likelihood not implemented"); +} + +/* ************************************************************************** */ +size_t DiscreteTableConditional::argmax( + const DiscreteValues& parentsValues) const { + // Initialize + size_t maxValue = 0; + double maxP = 0; + DiscreteValues values = parentsValues; + + assert(nrFrontals() == 1); + Key j = firstFrontalKey(); + for (size_t value = 0; value < cardinality(j); value++) { + values[j] = value; + double pValueS = (*this)(values); + // Update MPE solution if better + if (pValueS > maxP) { + maxP = pValueS; + maxValue = value; + } + } + return maxValue; +} + +} // namespace gtsam diff --git a/gtsam/discrete/DiscreteTableConditional.h b/gtsam/discrete/DiscreteTableConditional.h new file mode 100644 index 000000000..28e35277d --- /dev/null +++ b/gtsam/discrete/DiscreteTableConditional.h @@ -0,0 +1,224 @@ +/* ---------------------------------------------------------------------------- + + * GTSAM Copyright 2010, Georgia Tech Research Corporation, + * Atlanta, Georgia 30332-0415 + * All Rights Reserved + * Authors: Frank Dellaert, et al. (see THANKS for the full author list) + + * See LICENSE for the license information + + * -------------------------------------------------------------------------- */ + +/** + * @file DiscreteTableConditional.h + * @date Dec 22, 2024 + * @author Varun Agrawal + */ + +#pragma once + +#include +#include +#include +#include + +#include +#include +#include + +namespace gtsam { + +/** + * Discrete Conditional Density which uses a SparseTable as the internal + * representation, similar to the TableFactor. + * + * @ingroup discrete + */ +class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional { + Eigen::SparseVector sparse_table_; + + public: + // typedefs needed to play nice with gtsam + typedef DiscreteTableConditional This; ///< Typedef to this class + typedef std::shared_ptr shared_ptr; ///< shared_ptr to this class + typedef DiscreteConditional + BaseConditional; ///< Typedef to our conditional base class + + using Values = DiscreteValues; ///< backwards compatibility + + /// @name Standard Constructors + /// @{ + + /// Default constructor needed for serialization. + DiscreteTableConditional() {} + + /// Construct from factor, taking the first `nFrontals` keys as frontals. + DiscreteTableConditional(size_t nFrontals, const TableFactor& f); + + /** + * Construct from DiscreteKeys and SparseVector, taking the first + * `nFrontals` keys as frontals, in the order given. + */ + DiscreteTableConditional(size_t nFrontals, const DiscreteKeys& keys, + const Eigen::SparseVector& potentials); + + /** Construct from signature */ + explicit DiscreteTableConditional(const Signature& signature); + + /** + * Construct from key, parents, and a Signature::Table specifying the + * conditional probability table (CPT) in 00 01 10 11 order. For + * three-valued, it would be 00 01 02 10 11 12 20 21 22, etc.... + * + * Example: DiscreteTableConditional P(D, {B,E}, table); + */ + DiscreteTableConditional(const DiscreteKey& key, const DiscreteKeys& parents, + const Signature::Table& table) + : DiscreteTableConditional(Signature(key, parents, table)) {} + + /** + * Construct from key, parents, and a vector specifying the + * conditional probability table (CPT) in 00 01 10 11 order. For + * three-valued, it would be 00 01 02 10 11 12 20 21 22, etc.... + * + * Example: DiscreteTableConditional P(D, {B,E}, table); + */ + DiscreteTableConditional(const DiscreteKey& key, const DiscreteKeys& parents, + const std::vector& table) + : DiscreteTableConditional( + 1, TableFactor(DiscreteKeys{key} & parents, table)) {} + + /** + * Construct from key, parents, and a string specifying the conditional + * probability table (CPT) in 00 01 10 11 order. For three-valued, it would + * be 00 01 02 10 11 12 20 21 22, etc.... + * + * The string is parsed into a Signature::Table. + * + * Example: DiscreteTableConditional P(D, {B,E}, "9/1 2/8 3/7 1/9"); + */ + DiscreteTableConditional(const DiscreteKey& key, const DiscreteKeys& parents, + const std::string& spec) + : DiscreteTableConditional(Signature(key, parents, spec)) {} + + /// No-parent specialization; can also use DiscreteDistribution. + DiscreteTableConditional(const DiscreteKey& key, const std::string& spec) + : DiscreteTableConditional(Signature(key, {}, spec)) {} + + /** + * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) + * Assumes but *does not check* that f(Y)=sum_X f(X,Y). + */ + DiscreteTableConditional(const TableFactor& joint, + const TableFactor& marginal); + + /** + * @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y) + * Assumes but *does not check* that f(Y)=sum_X f(X,Y). + * Makes sure the keys are ordered as given. Does not check orderedKeys. + */ + DiscreteTableConditional(const TableFactor& joint, + const TableFactor& marginal, + const Ordering& orderedKeys); + + /** + * @brief Combine two conditionals, yielding a new conditional with the union + * of the frontal keys, ordered by gtsam::Key. + * + * The two conditionals must make a valid Bayes net fragment, i.e., + * the frontal variables cannot overlap, and must be acyclic: + * Example of correct use: + * P(A,B) = P(A|B) * P(B) + * P(A,B|C) = P(A|B) * P(B|C) + * P(A,B,C) = P(A,B|C) * P(C) + * Example of incorrect use: + * P(A|B) * P(A|C) = ? + * P(A|B) * P(B|A) = ? + * We check for overlapping frontals, but do *not* check for cyclic. + */ + DiscreteTableConditional operator*( + const DiscreteTableConditional& other) const; + + /// @} + /// @name Testable + /// @{ + + /// GTSAM-style print + void print( + const std::string& s = "Discrete Conditional: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const override; + + /// GTSAM-style equals + bool equals(const DiscreteFactor& other, double tol = 1e-9) const override; + + /// @} + /// @name Standard Interface + /// @{ + + /// Log-probability is just -error(x). + double logProbability(const DiscreteValues& x) const { return -error(x); } + + /// print index signature only + void printSignature( + const std::string& s = "Discrete Conditional: ", + const KeyFormatter& formatter = DefaultKeyFormatter) const { + static_cast(this)->print(s, formatter); + } + + /** Convert to a likelihood factor by providing value before bar. */ + TableFactor::shared_ptr likelihood(const DiscreteValues& frontalValues) const; + + /** Single variable version of likelihood. */ + TableFactor::shared_ptr likelihood(size_t frontal) const; + + /** + * @brief Return assignment for single frontal variable that maximizes value. + * @param parentsValues Known assignments for the parents. + * @return maximizing assignment for the frontal variable. + */ + size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const; + + /// @} + /// @name Advanced Interface + /// @{ + + /// Return all assignments for frontal variables. + std::vector frontalAssignments() const; + + /// Return all assignments for frontal *and* parent variables. + std::vector allAssignments() const; + + /// @} + /// @name HybridValues methods. + /// @{ + + using BaseConditional::operator(); ///< HybridValues version + + /** + * Calculate log-probability log(evaluate(x)) for HybridValues `x`. + * This is actually just -error(x). + */ + double logProbability(const HybridValues& x) const override { + return -error(x); + } + + /// @} + + private: +#if GTSAM_ENABLE_BOOST_SERIALIZATION + /** Serialization function */ + friend class boost::serialization::access; + template + void serialize(Archive& ar, const unsigned int /*version*/) { + ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional); + } +#endif +}; +// DiscreteTableConditional + +// traits +template <> +struct traits + : public Testable {}; + +} // namespace gtsam