initial DiscreteTableConditional
parent
34fba6823a
commit
de652eafc2
|
|
@ -0,0 +1,181 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file DiscreteTableConditional.cpp
|
||||||
|
* @date Dec 22, 2024
|
||||||
|
* @author Varun Agrawal
|
||||||
|
*/
|
||||||
|
|
||||||
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/base/debug.h>
|
||||||
|
#include <gtsam/discrete/DiscreteTableConditional.h>
|
||||||
|
#include <gtsam/discrete/Ring.h>
|
||||||
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
|
|
||||||
|
#include <algorithm>
|
||||||
|
#include <cassert>
|
||||||
|
#include <random>
|
||||||
|
#include <set>
|
||||||
|
#include <stdexcept>
|
||||||
|
#include <string>
|
||||||
|
#include <utility>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
using namespace std;
|
||||||
|
using std::pair;
|
||||||
|
using std::stringstream;
|
||||||
|
using std::vector;
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
DiscreteTableConditional::DiscreteTableConditional(const size_t nrFrontals,
|
||||||
|
const TableFactor& f)
|
||||||
|
: BaseConditional(nrFrontals, DecisionTreeFactor(f.discreteKeys(), ADT())),
|
||||||
|
sparse_table_((f / (*f.sum(nrFrontals))).sparseTable()) {
|
||||||
|
// sparse_table_ = sparse_table_.prune();
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
DiscreteTableConditional::DiscreteTableConditional(
|
||||||
|
size_t nrFrontals, const DiscreteKeys& keys,
|
||||||
|
const Eigen::SparseVector<double>& potentials)
|
||||||
|
: BaseConditional(nrFrontals, keys, DecisionTreeFactor(keys, ADT())),
|
||||||
|
sparse_table_(potentials) {}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint,
|
||||||
|
const TableFactor& marginal)
|
||||||
|
: BaseConditional(joint.size() - marginal.size(),
|
||||||
|
joint.discreteKeys() & marginal.discreteKeys(), ADT()),
|
||||||
|
sparse_table_((joint / marginal).sparseTable()) {}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
DiscreteTableConditional::DiscreteTableConditional(const TableFactor& joint,
|
||||||
|
const TableFactor& marginal,
|
||||||
|
const Ordering& orderedKeys)
|
||||||
|
: DiscreteTableConditional(joint, marginal) {
|
||||||
|
keys_.clear();
|
||||||
|
keys_.insert(keys_.end(), orderedKeys.begin(), orderedKeys.end());
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
DiscreteTableConditional::DiscreteTableConditional(const Signature& signature)
|
||||||
|
: BaseConditional(1, DecisionTreeFactor()),
|
||||||
|
sparse_table_(TableFactor(signature.discreteKeys(), signature.cpt())
|
||||||
|
.sparseTable()) {}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
DiscreteTableConditional DiscreteTableConditional::operator*(
|
||||||
|
const DiscreteTableConditional& other) const {
|
||||||
|
// Take union of frontal keys
|
||||||
|
std::set<Key> newFrontals;
|
||||||
|
for (auto&& key : this->frontals()) newFrontals.insert(key);
|
||||||
|
for (auto&& key : other.frontals()) newFrontals.insert(key);
|
||||||
|
|
||||||
|
// Check if frontals overlapped
|
||||||
|
if (nrFrontals() + other.nrFrontals() > newFrontals.size())
|
||||||
|
throw std::invalid_argument(
|
||||||
|
"DiscreteTableConditional::operator* called with overlapping frontal "
|
||||||
|
"keys.");
|
||||||
|
|
||||||
|
// Now, add cardinalities.
|
||||||
|
DiscreteKeys discreteKeys;
|
||||||
|
for (auto&& key : frontals())
|
||||||
|
discreteKeys.emplace_back(key, cardinality(key));
|
||||||
|
for (auto&& key : other.frontals())
|
||||||
|
discreteKeys.emplace_back(key, other.cardinality(key));
|
||||||
|
|
||||||
|
// Sort
|
||||||
|
std::sort(discreteKeys.begin(), discreteKeys.end());
|
||||||
|
|
||||||
|
// Add parents to set, to make them unique
|
||||||
|
std::set<DiscreteKey> parents;
|
||||||
|
for (auto&& key : this->parents())
|
||||||
|
if (!newFrontals.count(key)) parents.emplace(key, cardinality(key));
|
||||||
|
for (auto&& key : other.parents())
|
||||||
|
if (!newFrontals.count(key)) parents.emplace(key, other.cardinality(key));
|
||||||
|
|
||||||
|
// Finally, add parents to keys, in order
|
||||||
|
for (auto&& dk : parents) discreteKeys.push_back(dk);
|
||||||
|
|
||||||
|
TableFactor a(this->discreteKeys(), this->sparse_table_),
|
||||||
|
b(other.discreteKeys(), other.sparse_table_);
|
||||||
|
TableFactor product = a * other;
|
||||||
|
return DiscreteTableConditional(newFrontals.size(), product);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
void DiscreteTableConditional::print(const string& s,
|
||||||
|
const KeyFormatter& formatter) const {
|
||||||
|
cout << s << " P( ";
|
||||||
|
for (const_iterator it = beginFrontals(); it != endFrontals(); ++it) {
|
||||||
|
cout << formatter(*it) << " ";
|
||||||
|
}
|
||||||
|
if (nrParents()) {
|
||||||
|
cout << "| ";
|
||||||
|
for (const_iterator it = beginParents(); it != endParents(); ++it) {
|
||||||
|
cout << formatter(*it) << " ";
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cout << "):\n";
|
||||||
|
// BaseFactor::print("", formatter);
|
||||||
|
cout << endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
bool DiscreteTableConditional::equals(const DiscreteFactor& other,
|
||||||
|
double tol) const {
|
||||||
|
if (!dynamic_cast<const DiscreteConditional*>(&other)) {
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
const DiscreteConditional& f(
|
||||||
|
static_cast<const DiscreteConditional&>(other));
|
||||||
|
return DiscreteConditional::equals(f, tol);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
TableFactor::shared_ptr DiscreteTableConditional::likelihood(
|
||||||
|
const DiscreteValues& frontalValues) const {
|
||||||
|
throw std::runtime_error("Likelihood not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ****************************************************************************/
|
||||||
|
TableFactor::shared_ptr DiscreteTableConditional::likelihood(
|
||||||
|
size_t frontal) const {
|
||||||
|
throw std::runtime_error("Likelihood not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************** */
|
||||||
|
size_t DiscreteTableConditional::argmax(
|
||||||
|
const DiscreteValues& parentsValues) const {
|
||||||
|
// Initialize
|
||||||
|
size_t maxValue = 0;
|
||||||
|
double maxP = 0;
|
||||||
|
DiscreteValues values = parentsValues;
|
||||||
|
|
||||||
|
assert(nrFrontals() == 1);
|
||||||
|
Key j = firstFrontalKey();
|
||||||
|
for (size_t value = 0; value < cardinality(j); value++) {
|
||||||
|
values[j] = value;
|
||||||
|
double pValueS = (*this)(values);
|
||||||
|
// Update MPE solution if better
|
||||||
|
if (pValueS > maxP) {
|
||||||
|
maxP = pValueS;
|
||||||
|
maxValue = value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return maxValue;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace gtsam
|
||||||
|
|
@ -0,0 +1,224 @@
|
||||||
|
/* ----------------------------------------------------------------------------
|
||||||
|
|
||||||
|
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||||
|
* Atlanta, Georgia 30332-0415
|
||||||
|
* All Rights Reserved
|
||||||
|
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||||
|
|
||||||
|
* See LICENSE for the license information
|
||||||
|
|
||||||
|
* -------------------------------------------------------------------------- */
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @file DiscreteTableConditional.h
|
||||||
|
* @date Dec 22, 2024
|
||||||
|
* @author Varun Agrawal
|
||||||
|
*/
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <gtsam/discrete/DiscreteConditional.h>
|
||||||
|
#include <gtsam/discrete/Signature.h>
|
||||||
|
#include <gtsam/discrete/TableFactor.h>
|
||||||
|
#include <gtsam/inference/Conditional-inst.h>
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
namespace gtsam {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Discrete Conditional Density which uses a SparseTable as the internal
|
||||||
|
* representation, similar to the TableFactor.
|
||||||
|
*
|
||||||
|
* @ingroup discrete
|
||||||
|
*/
|
||||||
|
class GTSAM_EXPORT DiscreteTableConditional : public DiscreteConditional {
|
||||||
|
Eigen::SparseVector<double> sparse_table_;
|
||||||
|
|
||||||
|
public:
|
||||||
|
// typedefs needed to play nice with gtsam
|
||||||
|
typedef DiscreteTableConditional This; ///< Typedef to this class
|
||||||
|
typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
||||||
|
typedef DiscreteConditional
|
||||||
|
BaseConditional; ///< Typedef to our conditional base class
|
||||||
|
|
||||||
|
using Values = DiscreteValues; ///< backwards compatibility
|
||||||
|
|
||||||
|
/// @name Standard Constructors
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Default constructor needed for serialization.
|
||||||
|
DiscreteTableConditional() {}
|
||||||
|
|
||||||
|
/// Construct from factor, taking the first `nFrontals` keys as frontals.
|
||||||
|
DiscreteTableConditional(size_t nFrontals, const TableFactor& f);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct from DiscreteKeys and SparseVector, taking the first
|
||||||
|
* `nFrontals` keys as frontals, in the order given.
|
||||||
|
*/
|
||||||
|
DiscreteTableConditional(size_t nFrontals, const DiscreteKeys& keys,
|
||||||
|
const Eigen::SparseVector<double>& potentials);
|
||||||
|
|
||||||
|
/** Construct from signature */
|
||||||
|
explicit DiscreteTableConditional(const Signature& signature);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct from key, parents, and a Signature::Table specifying the
|
||||||
|
* conditional probability table (CPT) in 00 01 10 11 order. For
|
||||||
|
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
|
||||||
|
*
|
||||||
|
* Example: DiscreteTableConditional P(D, {B,E}, table);
|
||||||
|
*/
|
||||||
|
DiscreteTableConditional(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||||
|
const Signature::Table& table)
|
||||||
|
: DiscreteTableConditional(Signature(key, parents, table)) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct from key, parents, and a vector<double> specifying the
|
||||||
|
* conditional probability table (CPT) in 00 01 10 11 order. For
|
||||||
|
* three-valued, it would be 00 01 02 10 11 12 20 21 22, etc....
|
||||||
|
*
|
||||||
|
* Example: DiscreteTableConditional P(D, {B,E}, table);
|
||||||
|
*/
|
||||||
|
DiscreteTableConditional(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||||
|
const std::vector<double>& table)
|
||||||
|
: DiscreteTableConditional(
|
||||||
|
1, TableFactor(DiscreteKeys{key} & parents, table)) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Construct from key, parents, and a string specifying the conditional
|
||||||
|
* probability table (CPT) in 00 01 10 11 order. For three-valued, it would
|
||||||
|
* be 00 01 02 10 11 12 20 21 22, etc....
|
||||||
|
*
|
||||||
|
* The string is parsed into a Signature::Table.
|
||||||
|
*
|
||||||
|
* Example: DiscreteTableConditional P(D, {B,E}, "9/1 2/8 3/7 1/9");
|
||||||
|
*/
|
||||||
|
DiscreteTableConditional(const DiscreteKey& key, const DiscreteKeys& parents,
|
||||||
|
const std::string& spec)
|
||||||
|
: DiscreteTableConditional(Signature(key, parents, spec)) {}
|
||||||
|
|
||||||
|
/// No-parent specialization; can also use DiscreteDistribution.
|
||||||
|
DiscreteTableConditional(const DiscreteKey& key, const std::string& spec)
|
||||||
|
: DiscreteTableConditional(Signature(key, {}, spec)) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
||||||
|
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
||||||
|
*/
|
||||||
|
DiscreteTableConditional(const TableFactor& joint,
|
||||||
|
const TableFactor& marginal);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief construct P(X|Y) = f(X,Y)/f(Y) from f(X,Y) and f(Y)
|
||||||
|
* Assumes but *does not check* that f(Y)=sum_X f(X,Y).
|
||||||
|
* Makes sure the keys are ordered as given. Does not check orderedKeys.
|
||||||
|
*/
|
||||||
|
DiscreteTableConditional(const TableFactor& joint,
|
||||||
|
const TableFactor& marginal,
|
||||||
|
const Ordering& orderedKeys);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Combine two conditionals, yielding a new conditional with the union
|
||||||
|
* of the frontal keys, ordered by gtsam::Key.
|
||||||
|
*
|
||||||
|
* The two conditionals must make a valid Bayes net fragment, i.e.,
|
||||||
|
* the frontal variables cannot overlap, and must be acyclic:
|
||||||
|
* Example of correct use:
|
||||||
|
* P(A,B) = P(A|B) * P(B)
|
||||||
|
* P(A,B|C) = P(A|B) * P(B|C)
|
||||||
|
* P(A,B,C) = P(A,B|C) * P(C)
|
||||||
|
* Example of incorrect use:
|
||||||
|
* P(A|B) * P(A|C) = ?
|
||||||
|
* P(A|B) * P(B|A) = ?
|
||||||
|
* We check for overlapping frontals, but do *not* check for cyclic.
|
||||||
|
*/
|
||||||
|
DiscreteTableConditional operator*(
|
||||||
|
const DiscreteTableConditional& other) const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Testable
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// GTSAM-style print
|
||||||
|
void print(
|
||||||
|
const std::string& s = "Discrete Conditional: ",
|
||||||
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
|
/// GTSAM-style equals
|
||||||
|
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Standard Interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Log-probability is just -error(x).
|
||||||
|
double logProbability(const DiscreteValues& x) const { return -error(x); }
|
||||||
|
|
||||||
|
/// print index signature only
|
||||||
|
void printSignature(
|
||||||
|
const std::string& s = "Discrete Conditional: ",
|
||||||
|
const KeyFormatter& formatter = DefaultKeyFormatter) const {
|
||||||
|
static_cast<const BaseConditional*>(this)->print(s, formatter);
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Convert to a likelihood factor by providing value before bar. */
|
||||||
|
TableFactor::shared_ptr likelihood(const DiscreteValues& frontalValues) const;
|
||||||
|
|
||||||
|
/** Single variable version of likelihood. */
|
||||||
|
TableFactor::shared_ptr likelihood(size_t frontal) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return assignment for single frontal variable that maximizes value.
|
||||||
|
* @param parentsValues Known assignments for the parents.
|
||||||
|
* @return maximizing assignment for the frontal variable.
|
||||||
|
*/
|
||||||
|
size_t argmax(const DiscreteValues& parentsValues = DiscreteValues()) const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Advanced Interface
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Return all assignments for frontal variables.
|
||||||
|
std::vector<DiscreteValues> frontalAssignments() const;
|
||||||
|
|
||||||
|
/// Return all assignments for frontal *and* parent variables.
|
||||||
|
std::vector<DiscreteValues> allAssignments() const;
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name HybridValues methods.
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
using BaseConditional::operator(); ///< HybridValues version
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate log-probability log(evaluate(x)) for HybridValues `x`.
|
||||||
|
* This is actually just -error(x).
|
||||||
|
*/
|
||||||
|
double logProbability(const HybridValues& x) const override {
|
||||||
|
return -error(x);
|
||||||
|
}
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
|
||||||
|
private:
|
||||||
|
#if GTSAM_ENABLE_BOOST_SERIALIZATION
|
||||||
|
/** Serialization function */
|
||||||
|
friend class boost::serialization::access;
|
||||||
|
template <class Archive>
|
||||||
|
void serialize(Archive& ar, const unsigned int /*version*/) {
|
||||||
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
};
|
||||||
|
// DiscreteTableConditional
|
||||||
|
|
||||||
|
// traits
|
||||||
|
template <>
|
||||||
|
struct traits<DiscreteTableConditional>
|
||||||
|
: public Testable<DiscreteTableConditional> {};
|
||||||
|
|
||||||
|
} // namespace gtsam
|
||||||
Loading…
Reference in New Issue