238 lines
7.9 KiB
C++
238 lines
7.9 KiB
C++
/* ----------------------------------------------------------------------------

 * GTSAM Copyright 2010, Georgia Tech Research Corporation,
 * Atlanta, Georgia 30332-0415
 * All Rights Reserved
 * Authors: Frank Dellaert, et al. (see THANKS for the full author list)

 * See LICENSE for the license information

 * -------------------------------------------------------------------------- */
|
|
|
|
/**
 * @file HybridConditional.h
 * @date Mar 11, 2022
 * @author Fan Jiang
 */
|
|
|
|
#pragma once
|
|
|
|
#include <gtsam/discrete/DiscreteConditional.h>
|
|
#include <gtsam/hybrid/HybridGaussianConditional.h>
|
|
#include <gtsam/hybrid/HybridFactor.h>
|
|
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
|
#include <gtsam/inference/Conditional.h>
|
|
#include <gtsam/inference/Key.h>
|
|
#include <gtsam/linear/GaussianConditional.h>
|
|
|
|
#include <memory>
|
|
#include <stdexcept>
|
|
#include <string>
|
|
#include <typeinfo>
|
|
#include <vector>
|
|
|
|
namespace gtsam {
|
|
|
|
/**
|
|
* Hybrid Conditional Density
|
|
*
|
|
* As a type-erased variant of:
|
|
* - DiscreteConditional
|
|
* - GaussianConditional
|
|
* - HybridGaussianConditional
|
|
*
|
|
* The reason why this is important is that `Conditional<T>` is a CRTP class.
|
|
* CRTP is static polymorphism such that all CRTP classes, while bearing the
|
|
* same name, are different classes not sharing a vtable. This prevents them
|
|
* from being contained in any container, and thus it is impossible to
|
|
* dynamically cast between them. A better option, as illustrated here, is
|
|
* treating them as an implementation detail - such that the hybrid mechanism
|
|
* does not know what is inside the HybridConditional. This prevents us from
|
|
* having diamond inheritances, and neutralized the need to change other
|
|
* components of GTSAM to make hybrid elimination work.
|
|
*
|
|
* A great reference to the type-erasure pattern is Eduardo Madrid's CppCon
|
|
* talk (https://www.youtube.com/watch?v=s082Qmd_nHs).
|
|
*
|
|
* @ingroup hybrid
|
|
*/
|
|
class GTSAM_EXPORT HybridConditional
|
|
: public HybridFactor,
|
|
public Conditional<HybridFactor, HybridConditional> {
|
|
public:
|
|
// typedefs needed to play nice with gtsam
|
|
typedef HybridConditional This; ///< Typedef to this class
|
|
typedef std::shared_ptr<This> shared_ptr; ///< shared_ptr to this class
|
|
typedef HybridFactor BaseFactor; ///< Typedef to our factor base class
|
|
typedef Conditional<BaseFactor, This>
|
|
BaseConditional; ///< Typedef to our conditional base class
|
|
|
|
protected:
|
|
/// Type-erased pointer to the inner type
|
|
std::shared_ptr<Factor> inner_;
|
|
|
|
public:
|
|
/// @name Standard Constructors
|
|
/// @{
|
|
|
|
/// Default constructor needed for serialization.
|
|
HybridConditional() = default;
|
|
|
|
/**
|
|
* @brief Construct a new Hybrid Conditional object
|
|
*
|
|
* @param continuousKeys Vector of keys for continuous variables.
|
|
* @param discreteKeys Keys and cardinalities for discrete variables.
|
|
* @param nFrontals The number of frontal variables in the conditional.
|
|
*/
|
|
HybridConditional(const KeyVector& continuousKeys,
|
|
const DiscreteKeys& discreteKeys, size_t nFrontals)
|
|
: BaseFactor(continuousKeys, discreteKeys), BaseConditional(nFrontals) {}
|
|
|
|
/**
|
|
* @brief Construct a new Hybrid Conditional object
|
|
*
|
|
* @param continuousFrontals Vector of keys for continuous variables.
|
|
* @param discreteFrontals Keys and cardinalities for discrete variables.
|
|
* @param continuousParents Vector of keys for parent continuous variables.
|
|
* @param discreteParents Keys and cardinalities for parent discrete
|
|
* variables.
|
|
*/
|
|
HybridConditional(const KeyVector& continuousFrontals,
|
|
const DiscreteKeys& discreteFrontals,
|
|
const KeyVector& continuousParents,
|
|
const DiscreteKeys& discreteParents);
|
|
|
|
/**
|
|
* @brief Construct a new Hybrid Conditional object
|
|
*
|
|
* @param continuousConditional Conditional used to create the
|
|
* HybridConditional.
|
|
*/
|
|
HybridConditional(
|
|
const std::shared_ptr<GaussianConditional>& continuousConditional);
|
|
|
|
/**
|
|
* @brief Construct a new Hybrid Conditional object
|
|
*
|
|
* @param discreteConditional Conditional used to create the
|
|
* HybridConditional.
|
|
*/
|
|
HybridConditional(
|
|
const std::shared_ptr<DiscreteConditional>& discreteConditional);
|
|
|
|
/**
|
|
* @brief Construct a new Hybrid Conditional object
|
|
*
|
|
* @param gaussianMixture Gaussian Mixture Conditional used to create the
|
|
* HybridConditional.
|
|
*/
|
|
HybridConditional(const std::shared_ptr<HybridGaussianConditional>& gaussianMixture);
|
|
|
|
/// @}
|
|
/// @name Testable
|
|
/// @{
|
|
|
|
/// GTSAM-style print
|
|
void print(
|
|
const std::string& s = "Hybrid Conditional: ",
|
|
const KeyFormatter& formatter = DefaultKeyFormatter) const override;
|
|
|
|
/// GTSAM-style equals
|
|
bool equals(const HybridFactor& other, double tol = 1e-9) const override;
|
|
|
|
/// @}
|
|
/// @name Standard Interface
|
|
/// @{
|
|
|
|
/**
|
|
* @brief Return HybridConditional as a HybridGaussianConditional
|
|
* @return nullptr if not a mixture
|
|
* @return HybridGaussianConditional::shared_ptr otherwise
|
|
*/
|
|
HybridGaussianConditional::shared_ptr asMixture() const {
|
|
return std::dynamic_pointer_cast<HybridGaussianConditional>(inner_);
|
|
}
|
|
|
|
/**
|
|
* @brief Return HybridConditional as a GaussianConditional
|
|
* @return nullptr if not a GaussianConditional
|
|
* @return GaussianConditional::shared_ptr otherwise
|
|
*/
|
|
GaussianConditional::shared_ptr asGaussian() const {
|
|
return std::dynamic_pointer_cast<GaussianConditional>(inner_);
|
|
}
|
|
|
|
/**
|
|
* @brief Return conditional as a DiscreteConditional
|
|
* @return nullptr if not a DiscreteConditional
|
|
* @return DiscreteConditional::shared_ptr
|
|
*/
|
|
DiscreteConditional::shared_ptr asDiscrete() const {
|
|
return std::dynamic_pointer_cast<DiscreteConditional>(inner_);
|
|
}
|
|
|
|
/// Get the type-erased pointer to the inner type
|
|
std::shared_ptr<Factor> inner() const { return inner_; }
|
|
|
|
/// Return the error of the underlying conditional.
|
|
double error(const HybridValues& values) const override;
|
|
|
|
/// Return the log-probability (or density) of the underlying conditional.
|
|
double logProbability(const HybridValues& values) const override;
|
|
|
|
/**
|
|
* Return the log normalization constant.
|
|
* Note this is 0.0 for discrete and hybrid conditionals, but depends
|
|
* on the continuous parameters for Gaussian conditionals.
|
|
*/
|
|
double logNormalizationConstant() const override;
|
|
|
|
/// Return the probability (or density) of the underlying conditional.
|
|
double evaluate(const HybridValues& values) const override;
|
|
|
|
/// Check if VectorValues `measurements` contains all frontal keys.
|
|
bool frontalsIn(const VectorValues& measurements) const {
|
|
for (Key key : frontals()) {
|
|
if (!measurements.exists(key)) {
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
/// @}
|
|
|
|
private:
|
|
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
|
|
/** Serialization function */
|
|
friend class boost::serialization::access;
|
|
template <class Archive>
|
|
void serialize(Archive& ar, const unsigned int /*version*/) {
|
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseFactor);
|
|
ar& BOOST_SERIALIZATION_BASE_OBJECT_NVP(BaseConditional);
|
|
ar& BOOST_SERIALIZATION_NVP(inner_);
|
|
|
|
// register the various casts based on the type of inner_
|
|
// https://www.boost.org/doc/libs/1_80_0/libs/serialization/doc/serialization.html#runtimecasting
|
|
if (isDiscrete()) {
|
|
boost::serialization::void_cast_register<DiscreteConditional, Factor>(
|
|
static_cast<DiscreteConditional*>(NULL), static_cast<Factor*>(NULL));
|
|
} else if (isContinuous()) {
|
|
boost::serialization::void_cast_register<GaussianConditional, Factor>(
|
|
static_cast<GaussianConditional*>(NULL), static_cast<Factor*>(NULL));
|
|
} else {
|
|
boost::serialization::void_cast_register<HybridGaussianConditional, Factor>(
|
|
static_cast<HybridGaussianConditional*>(NULL), static_cast<Factor*>(NULL));
|
|
}
|
|
}
|
|
#endif
|
|
|
|
}; // HybridConditional
|
|
|
|
// traits
|
|
template <>
|
|
struct traits<HybridConditional> : public Testable<HybridConditional> {};
|
|
|
|
} // namespace gtsam
|