gtsam/gtsam/hybrid/HybridGaussianConditional.h

276 lines
9.4 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file HybridGaussianConditional.h
* @brief A hybrid conditional in the Conditional Linear Gaussian scheme
* @author Fan Jiang
* @author Varun Agrawal
* @author Frank Dellaert
* @date Mar 12, 2022
*/
#pragma once
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/inference/Conditional.h>
#include <gtsam/linear/GaussianConditional.h>
namespace gtsam {
class HybridValues;
/**
 * @brief A conditional of Gaussian conditionals indexed by discrete variables,
 * as part of a Bayes Network. This is the result of the elimination of a
 * continuous variable in a hybrid scheme, such that the remaining variables are
 * discrete+continuous.
 *
 * Represents the conditional density P(X | M, Z) where X is the set of
 * continuous frontal variables, M is the set of discrete parents whose
 * assignment selects one Gaussian component, and Z is the set of continuous
 * parents of this node.
 *
 * For a given discrete assignment m, the density P(x | y, z, ..., m) is
 * \f$ k_m \exp - \frac{1}{2} |R_m x - (d_m - S_m y - T_m z - ...)|^2 \f$
 * where k_m is the normalization constant of component m. Use choose() to
 * obtain the GaussianConditional for a specific discrete assignment.
 *
 * @ingroup hybrid
 */
class GTSAM_EXPORT HybridGaussianConditional
    : public HybridGaussianFactor,
      public Conditional<HybridGaussianFactor, HybridGaussianConditional> {
 public:
  using This = HybridGaussianConditional;
  using shared_ptr = std::shared_ptr<This>;
  using BaseFactor = HybridGaussianFactor;
  using BaseConditional = Conditional<BaseFactor, HybridGaussianConditional>;

  /// typedef for Decision Tree of Gaussian Conditionals
  using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;

 private:
  /// Negative log of the normalization constant, log(sqrt(|2*pi*Sigma|)).
  /// Working in neg-log space means everything is a minimization.
  /// The in-class initializer only applies to the default (serialization)
  /// constructor; the other constructors are expected to set this value.
  double negLogConstant_ = 0.0;

  /// Flag to indicate if the conditional has been pruned.
  bool pruned_ = false;

 public:
  /// @name Constructors
  /// @{

  /// Default constructor, mainly for serialization.
  HybridGaussianConditional() = default;

  /**
   * @brief Construct from one discrete key and vector of conditionals.
   *
   * @param discreteParent Single discrete parent variable
   * @param conditionals Vector of conditionals with the same size as the
   * cardinality of the discrete parent.
   */
  HybridGaussianConditional(
      const DiscreteKey &discreteParent,
      const std::vector<GaussianConditional::shared_ptr> &conditionals);

  /**
   * @brief Constructs a HybridGaussianConditional with means mu_i and
   * standard deviations sigma_i.
   *
   * @param discreteParent The discrete parent or "mode" key.
   * @param key The key for this conditional variable.
   * @param parameters A vector of pairs (mu_i, sigma_i).
   */
  HybridGaussianConditional(
      const DiscreteKey &discreteParent, Key key,
      const std::vector<std::pair<Vector, double>> &parameters);

  /**
   * @brief Constructs a HybridGaussianConditional with conditional means
   * A * parent + b_i and standard deviations sigma_i.
   *
   * @param discreteParent The discrete parent or "mode" key.
   * @param key The key for this conditional variable.
   * @param A The matrix A.
   * @param parent The key of the parent variable.
   * @param parameters A vector of pairs (b_i, sigma_i).
   */
  HybridGaussianConditional(
      const DiscreteKey &discreteParent, Key key, const Matrix &A, Key parent,
      const std::vector<std::pair<Vector, double>> &parameters);

  /**
   * @brief Constructs a HybridGaussianConditional with conditional means
   * A1 * parent1 + A2 * parent2 + b_i and standard deviations sigma_i.
   *
   * @param discreteParent The discrete parent or "mode" key.
   * @param key The key for this conditional variable.
   * @param A1 The first matrix.
   * @param parent1 The key of the first parent variable.
   * @param A2 The second matrix.
   * @param parent2 The key of the second parent variable.
   * @param parameters A vector of pairs (b_i, sigma_i).
   */
  HybridGaussianConditional(
      const DiscreteKey &discreteParent, Key key,  //
      const Matrix &A1, Key parent1, const Matrix &A2, Key parent2,
      const std::vector<std::pair<Vector, double>> &parameters);

  /**
   * @brief Construct from multiple discrete keys and conditional tree.
   *
   * @param discreteParents the discrete parents. Will be placed last.
   * @param conditionals a decision tree of GaussianConditionals, labeled by
   * the discreteParents. It should hold one conditional per joint assignment,
   * i.e. the product of the cardinalities of the DiscreteKeys in
   * discreteParents (C^n when all n keys share cardinality C).
   */
  HybridGaussianConditional(const DiscreteKeys &discreteParents,
                            const Conditionals &conditionals);

  /**
   * @brief Construct from multiple discrete keys M and a tree of
   * factor/scalar pairs, where the scalar is assumed to be the
   * negative log constant for each assignment m, up to a constant.
   *
   * @note Will throw if factors are not actually conditionals.
   *
   * @param discreteParents the discrete parents. Will be placed last.
   * @param pairs Decision tree of GaussianFactor/scalar pairs.
   * @param pruned Flag indicating if conditional has been pruned.
   */
  HybridGaussianConditional(const DiscreteKeys &discreteParents,
                            const FactorValuePairs &pairs, bool pruned = false);

  /// @}
  /// @name Testable
  /// @{

  /// Test equality with base HybridFactor
  bool equals(const HybridFactor &lf, double tol = 1e-9) const override;

  /// Print utility
  void print(
      const std::string &s = "HybridGaussianConditional\n",
      const KeyFormatter &formatter = DefaultKeyFormatter) const override;

  /// @}
  /// @name Standard API
  /// @{

  /// @brief Return the conditional Gaussian for the given discrete assignment.
  GaussianConditional::shared_ptr choose(
      const DiscreteValues &discreteValues) const;

  /// @brief Syntactic sugar for choose.
  GaussianConditional::shared_ptr operator()(
      const DiscreteValues &discreteValues) const {
    return choose(discreteValues);
  }

  /// Returns the total number of continuous components
  size_t nrComponents() const;

  /// Returns the continuous keys among the parents.
  KeyVector continuousParents() const;

  /**
   * @brief Return log normalization constant in negative log space.
   *
   * The log normalization constant is the min of the individual
   * log-normalization constants.
   *
   * @return double
   */
  double negLogConstant() const override { return negLogConstant_; }

  /**
   * Create a likelihood factor for a hybrid Gaussian conditional,
   * return a hybrid Gaussian factor on the parents.
   */
  std::shared_ptr<HybridGaussianFactor> likelihood(
      const VectorValues &given) const;

  /// Get Conditionals DecisionTree (dynamic cast from factors)
  /// @note Slow: avoid using in favor of factors(), which uses existing tree.
  const Conditionals conditionals() const;

  /**
   * @brief Compute the logProbability of this hybrid Gaussian conditional.
   *
   * @param values Continuous values and discrete assignment.
   * @return double
   */
  double logProbability(const HybridValues &values) const override;

  /// Calculate probability density for given `values`.
  double evaluate(const HybridValues &values) const override;

  /// Evaluate probability density, sugar.
  double operator()(const HybridValues &values) const {
    return evaluate(values);
  }

  /**
   * @brief Prune the decision tree of Gaussian factors as per the discrete
   * `discreteProbs`.
   *
   * @param discreteProbs A pruned set of probabilities for the discrete keys.
   * @return Shared pointer to possibly a pruned HybridGaussianConditional
   */
  HybridGaussianConditional::shared_ptr prune(
      const DecisionTreeFactor &discreteProbs) const;

  /// Return true if the conditional has already been pruned.
  bool pruned() const { return pruned_; }

  /// @}

 private:
  /// Helper struct for private constructor.
  struct Helper;

  /// Private constructor that uses helper struct above.
  HybridGaussianConditional(const DiscreteKeys &discreteParents,
                            Helper &&helper, bool pruned = false);

  /// Check whether `given` has values for all frontal keys.
  bool allFrontalsGiven(const VectorValues &given) const;

#if GTSAM_ENABLE_BOOST_SERIALIZATION
  /** Serialization function */
  friend class boost::serialization::access;
  template <class Archive>
  void serialize(Archive &ar, const unsigned int /*version*/) {
    ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
    ar &BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
    // These members belong to neither base class. Without them, a
    // deserialized object would return a default negLogConstant() instead of
    // the stored one, and would lose its pruned flag.
    ar &BOOST_SERIALIZATION_NVP(negLogConstant_);
    ar &BOOST_SERIALIZATION_NVP(pruned_);
  }
#endif
};
/**
 * @brief Gather the entries of a DiscreteKeys vector into a std::set.
 * @param discreteKeys The discrete keys to collect.
 * @return A std::set<DiscreteKey> built from `discreteKeys`.
 */
std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys);

/// Specialize gtsam::traits so HybridGaussianConditional is treated as a
/// Testable type.
template <>
struct traits<HybridGaussianConditional>
    : public Testable<HybridGaussianConditional> {};
} // namespace gtsam