186 lines
5.3 KiB
C++
186 lines
5.3 KiB
C++
/* ----------------------------------------------------------------------------

 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
 * Atlanta, Georgia 30332-0415
 * All Rights Reserved
 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)

 * See LICENSE for the license information

 * -------------------------------------------------------------------------- */

/**
 * @file TableDistribution.h
 * @date Dec 22, 2024
 * @author Varun Agrawal
 */
|
|
|
|
#pragma once
|
|
|
|
#include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/discrete/TableFactor.h>
#include <gtsam/inference/Conditional-inst.h>

#include <cstdint>
#include <memory>
#include <random>
#include <string>
#include <vector>
|
|
|
|
namespace gtsam {
|
|
|
|
/**
|
|
* Distribution which uses a SparseVector as the internal
|
|
* representation, similar to the TableFactor.
|
|
*
|
|
* This is primarily used in the case when we have a clique in the BayesTree
|
|
* which consists of all the discrete variables, e.g. in hybrid elimination.
|
|
*
|
|
* @ingroup discrete
|
|
*/
|
|
class GTSAM_EXPORT TableDistribution : public DiscreteConditional {
|
|
private:
|
|
TableFactor table_;
|
|
|
|
typedef Eigen::SparseVector<double>::InnerIterator SparseIt;
|
|
|
|
public:
|
|
// typedefs needed to play nice with gtsam
|
|
typedef TableDistribution This; ///< Typedef to this class
|
|
typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
|
typedef DiscreteConditional
|
|
BaseConditional; ///< Typedef to our conditional base class
|
|
|
|
using Values = DiscreteValues; ///< backwards compatibility
|
|
|
|
/// @name Standard Constructors
|
|
/// @{
|
|
|
|
/// Default constructor needed for serialization.
|
|
TableDistribution() {}
|
|
|
|
/// Construct from TableFactor.
|
|
TableDistribution(const TableFactor& f);
|
|
|
|
/**
|
|
* Construct from DiscreteKeys and std::vector.
|
|
*/
|
|
TableDistribution(const DiscreteKeys& keys,
|
|
const std::vector<double>& potentials);
|
|
|
|
/**
|
|
* Construct from single DiscreteKey and std::vector.
|
|
*/
|
|
TableDistribution(const DiscreteKey& key,
|
|
const std::vector<double>& potentials)
|
|
: TableDistribution(DiscreteKeys(key), potentials) {}
|
|
|
|
/**
|
|
* Construct from DiscreteKey and std::string.
|
|
*/
|
|
TableDistribution(const DiscreteKeys& keys, const std::string& potentials);
|
|
|
|
/**
|
|
* Construct from single DiscreteKey and std::string.
|
|
*/
|
|
TableDistribution(const DiscreteKey& key, const std::string& potentials)
|
|
: TableDistribution(DiscreteKeys(key), potentials) {}
|
|
|
|
/// @}
|
|
/// @name Testable
|
|
/// @{
|
|
|
|
/// GTSAM-style print
|
|
void print(
|
|
const std::string& s = "Table Distribution: ",
|
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
|
|
|
/// GTSAM-style equals
|
|
bool equals(const DiscreteFactor& other, double tol = 1e-9) const override;
|
|
|
|
/// @}
|
|
/// @name Standard Interface
|
|
/// @{
|
|
|
|
/// Return the underlying TableFactor
|
|
TableFactor table() const { return table_; }
|
|
|
|
using BaseConditional::evaluate; // HybridValues version
|
|
|
|
/// Evaluate the conditional given the values.
|
|
virtual double evaluate(const Assignment<Key>& values) const override {
|
|
return table_.evaluate(values);
|
|
}
|
|
|
|
/// Create new factor by summing all values with the same separator values
|
|
DiscreteFactor::shared_ptr sum(size_t nrFrontals) const override;
|
|
|
|
/// Create new factor by summing all values with the same separator values
|
|
DiscreteFactor::shared_ptr sum(const Ordering& keys) const override;
|
|
|
|
/// Find the maximum value in the factor.
|
|
double max() const override { return table_.max(); }
|
|
|
|
/// Create new factor by maximizing over all values with the same separator.
|
|
DiscreteFactor::shared_ptr max(size_t nrFrontals) const override;
|
|
|
|
/// Create new factor by maximizing over all values with the same separator.
|
|
DiscreteFactor::shared_ptr max(const Ordering& keys) const override;
|
|
|
|
/// Multiply by scalar s
|
|
DiscreteFactor::shared_ptr operator*(double s) const override;
|
|
|
|
/// divide by DiscreteFactor::shared_ptr f (safely)
|
|
DiscreteFactor::shared_ptr operator/(
|
|
const DiscreteFactor::shared_ptr& f) const override;
|
|
|
|
/**
|
|
* @brief Return assignment that maximizes value.
|
|
*
|
|
* @return maximizing assignment for the variables.
|
|
*/
|
|
DiscreteValues argmax() const;
|
|
|
|
/**
|
|
* sample
|
|
* @param parentsValues Known values of the parents
|
|
* @param rng Pseudo random number generator
|
|
* @return sample from conditional
|
|
*/
|
|
virtual size_t sample(const DiscreteValues& parentsValues,
|
|
std::mt19937_64* rng = nullptr) const override;
|
|
|
|
/// @}
|
|
/// @name Advanced Interface
|
|
/// @{
|
|
|
|
/// Prune the conditional
|
|
virtual void prune(size_t maxNrAssignments) override;
|
|
|
|
/// Get a DecisionTreeFactor representation.
|
|
DecisionTreeFactor toDecisionTreeFactor() const override {
|
|
return table_.toDecisionTreeFactor();
|
|
}
|
|
|
|
/// Get the number of non-zero values.
|
|
uint64_t nrValues() const override { return table_.sparseTable().nonZeros(); }
|
|
|
|
/// @}
|
|
|
|
private:
|
|
#if GTSAM_ENABLE_BOOST_SERIALIZATION
|
|
/** Serialization function */
|
|
friend class boost::serialization::access;
|
|
template <class Archive>
|
|
void serialize(Archive& ar, const unsigned int /*version*/) {
|
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
|
|
ar& BOOST_SERIALIZATION_NVP(table_);
|
|
}
|
|
#endif
|
|
};
|
|
// TableDistribution
|
|
|
|
// traits
|
|
template <>
|
|
struct traits<TableDistribution> : public Testable<TableDistribution> {};
|
|
|
|
} // namespace gtsam
|