Hybrid and Gaussian Mixture conditional docs and some refactor
parent
3f239c28be
commit
c3a92a4705
|
|
@ -42,7 +42,7 @@ GaussianMixtureConditional::conditionals() {
|
|||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
GaussianMixtureConditional GaussianMixtureConditional::FromConditionalList(
|
||||
GaussianMixtureConditional GaussianMixtureConditional::FromConditionals(
|
||||
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
|
||||
const DiscreteKeys &discreteParents,
|
||||
const std::vector<GaussianConditional::shared_ptr> &conditionalsList) {
|
||||
|
|
@ -86,12 +86,12 @@ bool GaussianMixtureConditional::equals(const HybridFactor &lf,
|
|||
void GaussianMixtureConditional::print(const std::string &s,
|
||||
const KeyFormatter &formatter) const {
|
||||
std::cout << s << ": ";
|
||||
if (isContinuous_) std::cout << "Cont. ";
|
||||
if (isDiscrete_) std::cout << "Disc. ";
|
||||
if (isHybrid_) std::cout << "Hybr. ";
|
||||
if (isContinuous()) std::cout << "Cont. ";
|
||||
if (isDiscrete()) std::cout << "Disc. ";
|
||||
if (isHybrid()) std::cout << "Hybr. ";
|
||||
BaseConditional::print("", formatter);
|
||||
std::cout << "Discrete Keys = ";
|
||||
for (auto &dk : discreteKeys_) {
|
||||
for (auto &dk : discreteKeys()) {
|
||||
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
|
||||
}
|
||||
std::cout << "\n";
|
||||
|
|
|
|||
|
|
@ -25,6 +25,14 @@
|
|||
#include <gtsam/linear/GaussianConditional.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/**
|
||||
* @brief A conditional of gaussian mixtures indexed by discrete variables.
|
||||
*
|
||||
* Represents the conditional density P(X | M, Z) where X is a continuous random
|
||||
* variable, M is the discrete variable and Z is the set of measurements.
|
||||
*
|
||||
*/
|
||||
class GaussianMixtureConditional
|
||||
: public HybridFactor,
|
||||
public Conditional<HybridFactor, GaussianMixtureConditional> {
|
||||
|
|
@ -34,13 +42,28 @@ class GaussianMixtureConditional
|
|||
using BaseFactor = HybridFactor;
|
||||
using BaseConditional = Conditional<HybridFactor, GaussianMixtureConditional>;
|
||||
|
||||
/// Alias for DecisionTree of GaussianFactorGraphs
|
||||
using Sum = DecisionTree<Key, GaussianFactorGraph>;
|
||||
|
||||
/// typedef for Decision Tree of Gaussian Conditionals
|
||||
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
|
||||
|
||||
private:
|
||||
Conditionals conditionals_;
|
||||
|
||||
public:
|
||||
/**
|
||||
* @brief Construct a new Gaussian Mixture object
|
||||
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
|
||||
*/
|
||||
Sum asGaussianFactorGraphTree() const;
|
||||
|
||||
public:
|
||||
/// @name Constructors
|
||||
/// @{
|
||||
|
||||
/// Defaut constructor, mainly for serialization.
|
||||
GaussianMixtureConditional() = default;
|
||||
/**
|
||||
* @brief Construct a new GaussianMixtureConditional object
|
||||
*
|
||||
* @param continuousFrontals the continuous frontals.
|
||||
* @param continuousParents the continuous parents.
|
||||
|
|
@ -52,15 +75,6 @@ class GaussianMixtureConditional
|
|||
const DiscreteKeys &discreteParents,
|
||||
const Conditionals &conditionals);
|
||||
|
||||
using Sum = DecisionTree<Key, GaussianFactorGraph>;
|
||||
|
||||
const Conditionals &conditionals();
|
||||
|
||||
/**
|
||||
* @brief Combine Decision Trees
|
||||
*/
|
||||
Sum add(const Sum &sum) const;
|
||||
|
||||
/**
|
||||
* @brief Make a Gaussian Mixture from a list of Gaussian conditionals
|
||||
*
|
||||
|
|
@ -69,11 +83,15 @@ class GaussianMixtureConditional
|
|||
* @param discreteParents Discrete parents variables
|
||||
* @param conditionals List of conditionals
|
||||
*/
|
||||
static This FromConditionalList(
|
||||
static This FromConditionals(
|
||||
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
|
||||
const DiscreteKeys &discreteParents,
|
||||
const std::vector<GaussianConditional::shared_ptr> &conditionals);
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/// Test equality with base HybridFactor
|
||||
bool equals(const HybridFactor &lf, double tol = 1e-9) const override;
|
||||
|
||||
|
|
@ -82,11 +100,19 @@ class GaussianMixtureConditional
|
|||
const std::string &s = "GaussianMixtureConditional\n",
|
||||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||
|
||||
protected:
|
||||
/// @}
|
||||
|
||||
/// Getter for the underlying Conditionals DecisionTree
|
||||
const Conditionals &conditionals();
|
||||
|
||||
/**
|
||||
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
|
||||
* @brief Merge the Gaussian Factor Graphs in `this` and `sum` while
|
||||
* maintaining the decision tree structure.
|
||||
*
|
||||
* @param sum Decision Tree of Gaussian Factor Graphs
|
||||
* @return Sum
|
||||
*/
|
||||
Sum asGaussianFactorGraphTree() const;
|
||||
Sum add(const Sum &sum) const;
|
||||
};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
|
|||
|
|
@ -22,6 +22,7 @@
|
|||
|
||||
namespace gtsam {
|
||||
|
||||
/* ************************************************************************ */
|
||||
HybridConditional::HybridConditional(const KeyVector &continuousFrontals,
|
||||
const DiscreteKeys &discreteFrontals,
|
||||
const KeyVector &continuousParents,
|
||||
|
|
@ -35,36 +36,40 @@ HybridConditional::HybridConditional(const KeyVector &continuousFrontals,
|
|||
{discreteParents.begin(), discreteParents.end()}),
|
||||
continuousFrontals.size() + discreteFrontals.size()) {}
|
||||
|
||||
/* ************************************************************************ */
|
||||
HybridConditional::HybridConditional(
|
||||
boost::shared_ptr<GaussianConditional> continuousConditional)
|
||||
: HybridConditional(continuousConditional->keys(), {},
|
||||
continuousConditional->nrFrontals()) {
|
||||
inner = continuousConditional;
|
||||
inner_ = continuousConditional;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
HybridConditional::HybridConditional(
|
||||
boost::shared_ptr<DiscreteConditional> discreteConditional)
|
||||
: HybridConditional({}, discreteConditional->discreteKeys(),
|
||||
discreteConditional->nrFrontals()) {
|
||||
inner = discreteConditional;
|
||||
inner_ = discreteConditional;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
HybridConditional::HybridConditional(
|
||||
boost::shared_ptr<GaussianMixtureConditional> gaussianMixture)
|
||||
: BaseFactor(KeyVector(gaussianMixture->keys().begin(),
|
||||
gaussianMixture->keys().begin() +
|
||||
gaussianMixture->nrContinuous),
|
||||
gaussianMixture->discreteKeys_),
|
||||
gaussianMixture->nrContinuous()),
|
||||
gaussianMixture->discreteKeys()),
|
||||
BaseConditional(gaussianMixture->nrFrontals()) {
|
||||
inner = gaussianMixture;
|
||||
inner_ = gaussianMixture;
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
void HybridConditional::print(const std::string &s,
|
||||
const KeyFormatter &formatter) const {
|
||||
std::cout << s;
|
||||
if (isContinuous_) std::cout << "Cont. ";
|
||||
if (isDiscrete_) std::cout << "Disc. ";
|
||||
if (isHybrid_) std::cout << "Hybr. ";
|
||||
if (isContinuous()) std::cout << "Cont. ";
|
||||
if (isDiscrete()) std::cout << "Disc. ";
|
||||
if (isHybrid()) std::cout << "Hybr. ";
|
||||
std::cout << "P(";
|
||||
size_t index = 0;
|
||||
const size_t N = keys().size();
|
||||
|
|
@ -85,9 +90,10 @@ void HybridConditional::print(const std::string &s,
|
|||
index++;
|
||||
}
|
||||
std::cout << ")\n";
|
||||
if (inner) inner->print("", formatter);
|
||||
if (inner_) inner_->print("", formatter);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
bool HybridConditional::equals(const HybridFactor &other, double tol) const {
|
||||
return BaseFactor::equals(other, tol);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -54,8 +54,8 @@ class HybridFactorGraph;
|
|||
* having diamond inheritances, and neutralized the need to change other
|
||||
* components of GTSAM to make hybrid elimination work.
|
||||
*
|
||||
* A great reference to the type-erasure pattern is Edurado Madrid's CppCon
|
||||
* talk.
|
||||
* A great reference to the type-erasure pattern is Eduaado Madrid's CppCon
|
||||
* talk (https://www.youtube.com/watch?v=s082Qmd_nHs).
|
||||
*/
|
||||
class GTSAM_EXPORT HybridConditional
|
||||
: public HybridFactor,
|
||||
|
|
@ -70,7 +70,7 @@ class GTSAM_EXPORT HybridConditional
|
|||
|
||||
protected:
|
||||
// Type-erased pointer to the inner type
|
||||
boost::shared_ptr<Factor> inner;
|
||||
boost::shared_ptr<Factor> inner_;
|
||||
|
||||
public:
|
||||
/// @name Standard Constructors
|
||||
|
|
@ -79,35 +79,77 @@ class GTSAM_EXPORT HybridConditional
|
|||
/// Default constructor needed for serialization.
|
||||
HybridConditional() = default;
|
||||
|
||||
/**
|
||||
* @brief Construct a new Hybrid Conditional object
|
||||
*
|
||||
* @param continuousKeys Vector of keys for continuous variables.
|
||||
* @param discreteKeys Keys and cardinalities for discrete variables.
|
||||
* @param nFrontals The number of frontal variables in the conditional.
|
||||
*/
|
||||
HybridConditional(const KeyVector& continuousKeys,
|
||||
const DiscreteKeys& discreteKeys, size_t nFrontals)
|
||||
: BaseFactor(continuousKeys, discreteKeys), BaseConditional(nFrontals) {}
|
||||
|
||||
/**
|
||||
* @brief Construct a new Hybrid Conditional object
|
||||
*
|
||||
* @param continuousFrontals Vector of keys for continuous variables.
|
||||
* @param discreteFrontals Keys and cardinalities for discrete variables.
|
||||
* @param continuousParents Vector of keys for parent continuous variables.
|
||||
* @param discreteParents Keys and cardinalities for parent discrete
|
||||
* variables.
|
||||
*/
|
||||
HybridConditional(const KeyVector& continuousFrontals,
|
||||
const DiscreteKeys& discreteFrontals,
|
||||
const KeyVector& continuousParents,
|
||||
const DiscreteKeys& discreteParents);
|
||||
|
||||
/**
|
||||
* @brief Construct a new Hybrid Conditional object
|
||||
*
|
||||
* @param continuousConditional Conditional used to create the
|
||||
* HybridConditional.
|
||||
*/
|
||||
HybridConditional(
|
||||
boost::shared_ptr<GaussianConditional> continuousConditional);
|
||||
|
||||
/**
|
||||
* @brief Construct a new Hybrid Conditional object
|
||||
*
|
||||
* @param discreteConditional Conditional used to create the
|
||||
* HybridConditional.
|
||||
*/
|
||||
HybridConditional(boost::shared_ptr<DiscreteConditional> discreteConditional);
|
||||
|
||||
/**
|
||||
* @brief Construct a new Hybrid Conditional object
|
||||
*
|
||||
* @param gaussianMixture Gaussian Mixture Conditional used to create the
|
||||
* HybridConditional.
|
||||
*/
|
||||
HybridConditional(
|
||||
boost::shared_ptr<GaussianMixtureConditional> gaussianMixture);
|
||||
|
||||
/**
|
||||
* @brief Return HybridConditional as a GaussianMixtureConditional
|
||||
*
|
||||
* @return GaussianMixtureConditional::shared_ptr
|
||||
*/
|
||||
GaussianMixtureConditional::shared_ptr asMixture() {
|
||||
if (!isHybrid_) throw std::invalid_argument("Not a mixture");
|
||||
return boost::static_pointer_cast<GaussianMixtureConditional>(inner);
|
||||
if (!isHybrid()) throw std::invalid_argument("Not a mixture");
|
||||
return boost::static_pointer_cast<GaussianMixtureConditional>(inner_);
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Return conditional as a DiscreteConditional
|
||||
*
|
||||
* @return DiscreteConditional::shared_ptr
|
||||
*/
|
||||
DiscreteConditional::shared_ptr asDiscreteConditional() {
|
||||
if (!isDiscrete_) throw std::invalid_argument("Not a discrete conditional");
|
||||
return boost::static_pointer_cast<DiscreteConditional>(inner);
|
||||
if (!isDiscrete()) throw std::invalid_argument("Not a discrete conditional");
|
||||
return boost::static_pointer_cast<DiscreteConditional>(inner_);
|
||||
}
|
||||
|
||||
boost::shared_ptr<Factor> getInner() { return inner; }
|
||||
|
||||
/// @}
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
|
@ -122,11 +164,10 @@ class GTSAM_EXPORT HybridConditional
|
|||
|
||||
/// @}
|
||||
|
||||
friend std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> //
|
||||
EliminateHybrid(const HybridFactorGraph& factors,
|
||||
const Ordering& frontalKeys);
|
||||
};
|
||||
// DiscreteConditional
|
||||
/// Get the type-erased pointer to the inner type
|
||||
boost::shared_ptr<Factor> inner() { return inner_; }
|
||||
|
||||
}; // DiscreteConditional
|
||||
|
||||
// traits
|
||||
template <>
|
||||
|
|
|
|||
Loading…
Reference in New Issue