Hybrid and Gaussian Mixture conditional docs and some refactor
parent
3f239c28be
commit
c3a92a4705
|
|
@ -42,7 +42,7 @@ GaussianMixtureConditional::conditionals() {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianMixtureConditional GaussianMixtureConditional::FromConditionalList(
|
GaussianMixtureConditional GaussianMixtureConditional::FromConditionals(
|
||||||
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
|
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
|
||||||
const DiscreteKeys &discreteParents,
|
const DiscreteKeys &discreteParents,
|
||||||
const std::vector<GaussianConditional::shared_ptr> &conditionalsList) {
|
const std::vector<GaussianConditional::shared_ptr> &conditionalsList) {
|
||||||
|
|
@ -86,12 +86,12 @@ bool GaussianMixtureConditional::equals(const HybridFactor &lf,
|
||||||
void GaussianMixtureConditional::print(const std::string &s,
|
void GaussianMixtureConditional::print(const std::string &s,
|
||||||
const KeyFormatter &formatter) const {
|
const KeyFormatter &formatter) const {
|
||||||
std::cout << s << ": ";
|
std::cout << s << ": ";
|
||||||
if (isContinuous_) std::cout << "Cont. ";
|
if (isContinuous()) std::cout << "Cont. ";
|
||||||
if (isDiscrete_) std::cout << "Disc. ";
|
if (isDiscrete()) std::cout << "Disc. ";
|
||||||
if (isHybrid_) std::cout << "Hybr. ";
|
if (isHybrid()) std::cout << "Hybr. ";
|
||||||
BaseConditional::print("", formatter);
|
BaseConditional::print("", formatter);
|
||||||
std::cout << "Discrete Keys = ";
|
std::cout << "Discrete Keys = ";
|
||||||
for (auto &dk : discreteKeys_) {
|
for (auto &dk : discreteKeys()) {
|
||||||
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
|
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
|
||||||
}
|
}
|
||||||
std::cout << "\n";
|
std::cout << "\n";
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,14 @@
|
||||||
#include <gtsam/linear/GaussianConditional.h>
|
#include <gtsam/linear/GaussianConditional.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief A conditional of gaussian mixtures indexed by discrete variables.
|
||||||
|
*
|
||||||
|
* Represents the conditional density P(X | M, Z) where X is a continuous random
|
||||||
|
* variable, M is the discrete variable and Z is the set of measurements.
|
||||||
|
*
|
||||||
|
*/
|
||||||
class GaussianMixtureConditional
|
class GaussianMixtureConditional
|
||||||
: public HybridFactor,
|
: public HybridFactor,
|
||||||
public Conditional<HybridFactor, GaussianMixtureConditional> {
|
public Conditional<HybridFactor, GaussianMixtureConditional> {
|
||||||
|
|
@ -34,13 +42,28 @@ class GaussianMixtureConditional
|
||||||
using BaseFactor = HybridFactor;
|
using BaseFactor = HybridFactor;
|
||||||
using BaseConditional = Conditional<HybridFactor, GaussianMixtureConditional>;
|
using BaseConditional = Conditional<HybridFactor, GaussianMixtureConditional>;
|
||||||
|
|
||||||
|
/// Alias for DecisionTree of GaussianFactorGraphs
|
||||||
|
using Sum = DecisionTree<Key, GaussianFactorGraph>;
|
||||||
|
|
||||||
|
/// typedef for Decision Tree of Gaussian Conditionals
|
||||||
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
|
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
|
||||||
|
|
||||||
|
private:
|
||||||
Conditionals conditionals_;
|
Conditionals conditionals_;
|
||||||
|
|
||||||
public:
|
|
||||||
/**
|
/**
|
||||||
* @brief Construct a new Gaussian Mixture object
|
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
|
||||||
|
*/
|
||||||
|
Sum asGaussianFactorGraphTree() const;
|
||||||
|
|
||||||
|
public:
|
||||||
|
/// @name Constructors
|
||||||
|
/// @{
|
||||||
|
|
||||||
|
/// Defaut constructor, mainly for serialization.
|
||||||
|
GaussianMixtureConditional() = default;
|
||||||
|
/**
|
||||||
|
* @brief Construct a new GaussianMixtureConditional object
|
||||||
*
|
*
|
||||||
* @param continuousFrontals the continuous frontals.
|
* @param continuousFrontals the continuous frontals.
|
||||||
* @param continuousParents the continuous parents.
|
* @param continuousParents the continuous parents.
|
||||||
|
|
@ -52,15 +75,6 @@ class GaussianMixtureConditional
|
||||||
const DiscreteKeys &discreteParents,
|
const DiscreteKeys &discreteParents,
|
||||||
const Conditionals &conditionals);
|
const Conditionals &conditionals);
|
||||||
|
|
||||||
using Sum = DecisionTree<Key, GaussianFactorGraph>;
|
|
||||||
|
|
||||||
const Conditionals &conditionals();
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @brief Combine Decision Trees
|
|
||||||
*/
|
|
||||||
Sum add(const Sum &sum) const;
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Make a Gaussian Mixture from a list of Gaussian conditionals
|
* @brief Make a Gaussian Mixture from a list of Gaussian conditionals
|
||||||
*
|
*
|
||||||
|
|
@ -69,11 +83,15 @@ class GaussianMixtureConditional
|
||||||
* @param discreteParents Discrete parents variables
|
* @param discreteParents Discrete parents variables
|
||||||
* @param conditionals List of conditionals
|
* @param conditionals List of conditionals
|
||||||
*/
|
*/
|
||||||
static This FromConditionalList(
|
static This FromConditionals(
|
||||||
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
|
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
|
||||||
const DiscreteKeys &discreteParents,
|
const DiscreteKeys &discreteParents,
|
||||||
const std::vector<GaussianConditional::shared_ptr> &conditionals);
|
const std::vector<GaussianConditional::shared_ptr> &conditionals);
|
||||||
|
|
||||||
|
/// @}
|
||||||
|
/// @name Testable
|
||||||
|
/// @{
|
||||||
|
|
||||||
/// Test equality with base HybridFactor
|
/// Test equality with base HybridFactor
|
||||||
bool equals(const HybridFactor &lf, double tol = 1e-9) const override;
|
bool equals(const HybridFactor &lf, double tol = 1e-9) const override;
|
||||||
|
|
||||||
|
|
@ -82,11 +100,19 @@ class GaussianMixtureConditional
|
||||||
const std::string &s = "GaussianMixtureConditional\n",
|
const std::string &s = "GaussianMixtureConditional\n",
|
||||||
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
|
||||||
|
|
||||||
protected:
|
/// @}
|
||||||
|
|
||||||
|
/// Getter for the underlying Conditionals DecisionTree
|
||||||
|
const Conditionals &conditionals();
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
|
* @brief Merge the Gaussian Factor Graphs in `this` and `sum` while
|
||||||
|
* maintaining the decision tree structure.
|
||||||
|
*
|
||||||
|
* @param sum Decision Tree of Gaussian Factor Graphs
|
||||||
|
* @return Sum
|
||||||
*/
|
*/
|
||||||
Sum asGaussianFactorGraphTree() const;
|
Sum add(const Sum &sum) const;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,7 @@
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
HybridConditional::HybridConditional(const KeyVector &continuousFrontals,
|
HybridConditional::HybridConditional(const KeyVector &continuousFrontals,
|
||||||
const DiscreteKeys &discreteFrontals,
|
const DiscreteKeys &discreteFrontals,
|
||||||
const KeyVector &continuousParents,
|
const KeyVector &continuousParents,
|
||||||
|
|
@ -35,36 +36,40 @@ HybridConditional::HybridConditional(const KeyVector &continuousFrontals,
|
||||||
{discreteParents.begin(), discreteParents.end()}),
|
{discreteParents.begin(), discreteParents.end()}),
|
||||||
continuousFrontals.size() + discreteFrontals.size()) {}
|
continuousFrontals.size() + discreteFrontals.size()) {}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
HybridConditional::HybridConditional(
|
HybridConditional::HybridConditional(
|
||||||
boost::shared_ptr<GaussianConditional> continuousConditional)
|
boost::shared_ptr<GaussianConditional> continuousConditional)
|
||||||
: HybridConditional(continuousConditional->keys(), {},
|
: HybridConditional(continuousConditional->keys(), {},
|
||||||
continuousConditional->nrFrontals()) {
|
continuousConditional->nrFrontals()) {
|
||||||
inner = continuousConditional;
|
inner_ = continuousConditional;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
HybridConditional::HybridConditional(
|
HybridConditional::HybridConditional(
|
||||||
boost::shared_ptr<DiscreteConditional> discreteConditional)
|
boost::shared_ptr<DiscreteConditional> discreteConditional)
|
||||||
: HybridConditional({}, discreteConditional->discreteKeys(),
|
: HybridConditional({}, discreteConditional->discreteKeys(),
|
||||||
discreteConditional->nrFrontals()) {
|
discreteConditional->nrFrontals()) {
|
||||||
inner = discreteConditional;
|
inner_ = discreteConditional;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
HybridConditional::HybridConditional(
|
HybridConditional::HybridConditional(
|
||||||
boost::shared_ptr<GaussianMixtureConditional> gaussianMixture)
|
boost::shared_ptr<GaussianMixtureConditional> gaussianMixture)
|
||||||
: BaseFactor(KeyVector(gaussianMixture->keys().begin(),
|
: BaseFactor(KeyVector(gaussianMixture->keys().begin(),
|
||||||
gaussianMixture->keys().begin() +
|
gaussianMixture->keys().begin() +
|
||||||
gaussianMixture->nrContinuous),
|
gaussianMixture->nrContinuous()),
|
||||||
gaussianMixture->discreteKeys_),
|
gaussianMixture->discreteKeys()),
|
||||||
BaseConditional(gaussianMixture->nrFrontals()) {
|
BaseConditional(gaussianMixture->nrFrontals()) {
|
||||||
inner = gaussianMixture;
|
inner_ = gaussianMixture;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
void HybridConditional::print(const std::string &s,
|
void HybridConditional::print(const std::string &s,
|
||||||
const KeyFormatter &formatter) const {
|
const KeyFormatter &formatter) const {
|
||||||
std::cout << s;
|
std::cout << s;
|
||||||
if (isContinuous_) std::cout << "Cont. ";
|
if (isContinuous()) std::cout << "Cont. ";
|
||||||
if (isDiscrete_) std::cout << "Disc. ";
|
if (isDiscrete()) std::cout << "Disc. ";
|
||||||
if (isHybrid_) std::cout << "Hybr. ";
|
if (isHybrid()) std::cout << "Hybr. ";
|
||||||
std::cout << "P(";
|
std::cout << "P(";
|
||||||
size_t index = 0;
|
size_t index = 0;
|
||||||
const size_t N = keys().size();
|
const size_t N = keys().size();
|
||||||
|
|
@ -85,9 +90,10 @@ void HybridConditional::print(const std::string &s,
|
||||||
index++;
|
index++;
|
||||||
}
|
}
|
||||||
std::cout << ")\n";
|
std::cout << ")\n";
|
||||||
if (inner) inner->print("", formatter);
|
if (inner_) inner_->print("", formatter);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* ************************************************************************ */
|
||||||
bool HybridConditional::equals(const HybridFactor &other, double tol) const {
|
bool HybridConditional::equals(const HybridFactor &other, double tol) const {
|
||||||
return BaseFactor::equals(other, tol);
|
return BaseFactor::equals(other, tol);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -54,8 +54,8 @@ class HybridFactorGraph;
|
||||||
* having diamond inheritances, and neutralized the need to change other
|
* having diamond inheritances, and neutralized the need to change other
|
||||||
* components of GTSAM to make hybrid elimination work.
|
* components of GTSAM to make hybrid elimination work.
|
||||||
*
|
*
|
||||||
* A great reference to the type-erasure pattern is Edurado Madrid's CppCon
|
* A great reference to the type-erasure pattern is Eduaado Madrid's CppCon
|
||||||
* talk.
|
* talk (https://www.youtube.com/watch?v=s082Qmd_nHs).
|
||||||
*/
|
*/
|
||||||
class GTSAM_EXPORT HybridConditional
|
class GTSAM_EXPORT HybridConditional
|
||||||
: public HybridFactor,
|
: public HybridFactor,
|
||||||
|
|
@ -70,7 +70,7 @@ class GTSAM_EXPORT HybridConditional
|
||||||
|
|
||||||
protected:
|
protected:
|
||||||
// Type-erased pointer to the inner type
|
// Type-erased pointer to the inner type
|
||||||
boost::shared_ptr<Factor> inner;
|
boost::shared_ptr<Factor> inner_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/// @name Standard Constructors
|
/// @name Standard Constructors
|
||||||
|
|
@ -79,35 +79,77 @@ class GTSAM_EXPORT HybridConditional
|
||||||
/// Default constructor needed for serialization.
|
/// Default constructor needed for serialization.
|
||||||
HybridConditional() = default;
|
HybridConditional() = default;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Construct a new Hybrid Conditional object
|
||||||
|
*
|
||||||
|
* @param continuousKeys Vector of keys for continuous variables.
|
||||||
|
* @param discreteKeys Keys and cardinalities for discrete variables.
|
||||||
|
* @param nFrontals The number of frontal variables in the conditional.
|
||||||
|
*/
|
||||||
HybridConditional(const KeyVector& continuousKeys,
|
HybridConditional(const KeyVector& continuousKeys,
|
||||||
const DiscreteKeys& discreteKeys, size_t nFrontals)
|
const DiscreteKeys& discreteKeys, size_t nFrontals)
|
||||||
: BaseFactor(continuousKeys, discreteKeys), BaseConditional(nFrontals) {}
|
: BaseFactor(continuousKeys, discreteKeys), BaseConditional(nFrontals) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Construct a new Hybrid Conditional object
|
||||||
|
*
|
||||||
|
* @param continuousFrontals Vector of keys for continuous variables.
|
||||||
|
* @param discreteFrontals Keys and cardinalities for discrete variables.
|
||||||
|
* @param continuousParents Vector of keys for parent continuous variables.
|
||||||
|
* @param discreteParents Keys and cardinalities for parent discrete
|
||||||
|
* variables.
|
||||||
|
*/
|
||||||
HybridConditional(const KeyVector& continuousFrontals,
|
HybridConditional(const KeyVector& continuousFrontals,
|
||||||
const DiscreteKeys& discreteFrontals,
|
const DiscreteKeys& discreteFrontals,
|
||||||
const KeyVector& continuousParents,
|
const KeyVector& continuousParents,
|
||||||
const DiscreteKeys& discreteParents);
|
const DiscreteKeys& discreteParents);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Construct a new Hybrid Conditional object
|
||||||
|
*
|
||||||
|
* @param continuousConditional Conditional used to create the
|
||||||
|
* HybridConditional.
|
||||||
|
*/
|
||||||
HybridConditional(
|
HybridConditional(
|
||||||
boost::shared_ptr<GaussianConditional> continuousConditional);
|
boost::shared_ptr<GaussianConditional> continuousConditional);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Construct a new Hybrid Conditional object
|
||||||
|
*
|
||||||
|
* @param discreteConditional Conditional used to create the
|
||||||
|
* HybridConditional.
|
||||||
|
*/
|
||||||
HybridConditional(boost::shared_ptr<DiscreteConditional> discreteConditional);
|
HybridConditional(boost::shared_ptr<DiscreteConditional> discreteConditional);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Construct a new Hybrid Conditional object
|
||||||
|
*
|
||||||
|
* @param gaussianMixture Gaussian Mixture Conditional used to create the
|
||||||
|
* HybridConditional.
|
||||||
|
*/
|
||||||
HybridConditional(
|
HybridConditional(
|
||||||
boost::shared_ptr<GaussianMixtureConditional> gaussianMixture);
|
boost::shared_ptr<GaussianMixtureConditional> gaussianMixture);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return HybridConditional as a GaussianMixtureConditional
|
||||||
|
*
|
||||||
|
* @return GaussianMixtureConditional::shared_ptr
|
||||||
|
*/
|
||||||
GaussianMixtureConditional::shared_ptr asMixture() {
|
GaussianMixtureConditional::shared_ptr asMixture() {
|
||||||
if (!isHybrid_) throw std::invalid_argument("Not a mixture");
|
if (!isHybrid()) throw std::invalid_argument("Not a mixture");
|
||||||
return boost::static_pointer_cast<GaussianMixtureConditional>(inner);
|
return boost::static_pointer_cast<GaussianMixtureConditional>(inner_);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Return conditional as a DiscreteConditional
|
||||||
|
*
|
||||||
|
* @return DiscreteConditional::shared_ptr
|
||||||
|
*/
|
||||||
DiscreteConditional::shared_ptr asDiscreteConditional() {
|
DiscreteConditional::shared_ptr asDiscreteConditional() {
|
||||||
if (!isDiscrete_) throw std::invalid_argument("Not a discrete conditional");
|
if (!isDiscrete()) throw std::invalid_argument("Not a discrete conditional");
|
||||||
return boost::static_pointer_cast<DiscreteConditional>(inner);
|
return boost::static_pointer_cast<DiscreteConditional>(inner_);
|
||||||
}
|
}
|
||||||
|
|
||||||
boost::shared_ptr<Factor> getInner() { return inner; }
|
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
/// @name Testable
|
/// @name Testable
|
||||||
/// @{
|
/// @{
|
||||||
|
|
@ -122,11 +164,10 @@ class GTSAM_EXPORT HybridConditional
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
friend std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> //
|
/// Get the type-erased pointer to the inner type
|
||||||
EliminateHybrid(const HybridFactorGraph& factors,
|
boost::shared_ptr<Factor> inner() { return inner_; }
|
||||||
const Ordering& frontalKeys);
|
|
||||||
};
|
}; // DiscreteConditional
|
||||||
// DiscreteConditional
|
|
||||||
|
|
||||||
// traits
|
// traits
|
||||||
template <>
|
template <>
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue