Hybrid and Gaussian Mixture conditional docs and some refactor

release/4.3a0
Varun Agrawal 2022-05-27 15:12:19 -04:00
parent 3f239c28be
commit c3a92a4705
4 changed files with 116 additions and 43 deletions

View File

@ -42,7 +42,7 @@ GaussianMixtureConditional::conditionals() {
}
/* *******************************************************************************/
GaussianMixtureConditional GaussianMixtureConditional::FromConditionalList(
GaussianMixtureConditional GaussianMixtureConditional::FromConditionals(
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const std::vector<GaussianConditional::shared_ptr> &conditionalsList) {
@ -86,12 +86,12 @@ bool GaussianMixtureConditional::equals(const HybridFactor &lf,
void GaussianMixtureConditional::print(const std::string &s,
const KeyFormatter &formatter) const {
std::cout << s << ": ";
if (isContinuous_) std::cout << "Cont. ";
if (isDiscrete_) std::cout << "Disc. ";
if (isHybrid_) std::cout << "Hybr. ";
if (isContinuous()) std::cout << "Cont. ";
if (isDiscrete()) std::cout << "Disc. ";
if (isHybrid()) std::cout << "Hybr. ";
BaseConditional::print("", formatter);
std::cout << "Discrete Keys = ";
for (auto &dk : discreteKeys_) {
for (auto &dk : discreteKeys()) {
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
}
std::cout << "\n";

View File

@ -25,6 +25,14 @@
#include <gtsam/linear/GaussianConditional.h>
namespace gtsam {
/**
* @brief A conditional of gaussian mixtures indexed by discrete variables.
*
* Represents the conditional density P(X | M, Z) where X is a continuous random
* variable, M is the discrete variable and Z is the set of measurements.
*
*/
class GaussianMixtureConditional
: public HybridFactor,
public Conditional<HybridFactor, GaussianMixtureConditional> {
@ -34,13 +42,28 @@ class GaussianMixtureConditional
using BaseFactor = HybridFactor;
using BaseConditional = Conditional<HybridFactor, GaussianMixtureConditional>;
/// Alias for DecisionTree of GaussianFactorGraphs
using Sum = DecisionTree<Key, GaussianFactorGraph>;
/// typedef for Decision Tree of Gaussian Conditionals
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
private:
Conditionals conditionals_;
public:
/**
* @brief Construct a new Gaussian Mixture object
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
*/
Sum asGaussianFactorGraphTree() const;
public:
/// @name Constructors
/// @{
/// Defaut constructor, mainly for serialization.
GaussianMixtureConditional() = default;
/**
* @brief Construct a new GaussianMixtureConditional object
*
* @param continuousFrontals the continuous frontals.
* @param continuousParents the continuous parents.
@ -52,15 +75,6 @@ class GaussianMixtureConditional
const DiscreteKeys &discreteParents,
const Conditionals &conditionals);
using Sum = DecisionTree<Key, GaussianFactorGraph>;
const Conditionals &conditionals();
/**
* @brief Combine Decision Trees
*/
Sum add(const Sum &sum) const;
/**
* @brief Make a Gaussian Mixture from a list of Gaussian conditionals
*
@ -69,11 +83,15 @@ class GaussianMixtureConditional
* @param discreteParents Discrete parents variables
* @param conditionals List of conditionals
*/
static This FromConditionalList(
static This FromConditionals(
const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents,
const std::vector<GaussianConditional::shared_ptr> &conditionals);
/// @}
/// @name Testable
/// @{
/// Test equality with base HybridFactor
bool equals(const HybridFactor &lf, double tol = 1e-9) const override;
@ -82,11 +100,19 @@ class GaussianMixtureConditional
const std::string &s = "GaussianMixtureConditional\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override;
protected:
/// @}
/// Getter for the underlying Conditionals DecisionTree
const Conditionals &conditionals();
/**
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
* @brief Merge the Gaussian Factor Graphs in `this` and `sum` while
* maintaining the decision tree structure.
*
* @param sum Decision Tree of Gaussian Factor Graphs
* @return Sum
*/
Sum asGaussianFactorGraphTree() const;
Sum add(const Sum &sum) const;
};
} // namespace gtsam

View File

@ -22,6 +22,7 @@
namespace gtsam {
/* ************************************************************************ */
HybridConditional::HybridConditional(const KeyVector &continuousFrontals,
const DiscreteKeys &discreteFrontals,
const KeyVector &continuousParents,
@ -35,36 +36,40 @@ HybridConditional::HybridConditional(const KeyVector &continuousFrontals,
{discreteParents.begin(), discreteParents.end()}),
continuousFrontals.size() + discreteFrontals.size()) {}
/* ************************************************************************ */
HybridConditional::HybridConditional(
boost::shared_ptr<GaussianConditional> continuousConditional)
: HybridConditional(continuousConditional->keys(), {},
continuousConditional->nrFrontals()) {
inner = continuousConditional;
inner_ = continuousConditional;
}
/* ************************************************************************ */
HybridConditional::HybridConditional(
boost::shared_ptr<DiscreteConditional> discreteConditional)
: HybridConditional({}, discreteConditional->discreteKeys(),
discreteConditional->nrFrontals()) {
inner = discreteConditional;
inner_ = discreteConditional;
}
/* ************************************************************************ */
HybridConditional::HybridConditional(
boost::shared_ptr<GaussianMixtureConditional> gaussianMixture)
: BaseFactor(KeyVector(gaussianMixture->keys().begin(),
gaussianMixture->keys().begin() +
gaussianMixture->nrContinuous),
gaussianMixture->discreteKeys_),
gaussianMixture->nrContinuous()),
gaussianMixture->discreteKeys()),
BaseConditional(gaussianMixture->nrFrontals()) {
inner = gaussianMixture;
inner_ = gaussianMixture;
}
/* ************************************************************************ */
void HybridConditional::print(const std::string &s,
const KeyFormatter &formatter) const {
std::cout << s;
if (isContinuous_) std::cout << "Cont. ";
if (isDiscrete_) std::cout << "Disc. ";
if (isHybrid_) std::cout << "Hybr. ";
if (isContinuous()) std::cout << "Cont. ";
if (isDiscrete()) std::cout << "Disc. ";
if (isHybrid()) std::cout << "Hybr. ";
std::cout << "P(";
size_t index = 0;
const size_t N = keys().size();
@ -85,9 +90,10 @@ void HybridConditional::print(const std::string &s,
index++;
}
std::cout << ")\n";
if (inner) inner->print("", formatter);
if (inner_) inner_->print("", formatter);
}
/* ************************************************************************ */
bool HybridConditional::equals(const HybridFactor &other, double tol) const {
return BaseFactor::equals(other, tol);
}

View File

@ -54,8 +54,8 @@ class HybridFactorGraph;
* having diamond inheritances, and neutralized the need to change other
* components of GTSAM to make hybrid elimination work.
*
* A great reference to the type-erasure pattern is Edurado Madrid's CppCon
* talk.
* A great reference to the type-erasure pattern is Eduaado Madrid's CppCon
* talk (https://www.youtube.com/watch?v=s082Qmd_nHs).
*/
class GTSAM_EXPORT HybridConditional
: public HybridFactor,
@ -70,7 +70,7 @@ class GTSAM_EXPORT HybridConditional
protected:
// Type-erased pointer to the inner type
boost::shared_ptr<Factor> inner;
boost::shared_ptr<Factor> inner_;
public:
/// @name Standard Constructors
@ -79,35 +79,77 @@ class GTSAM_EXPORT HybridConditional
/// Default constructor needed for serialization.
HybridConditional() = default;
/**
* @brief Construct a new Hybrid Conditional object
*
* @param continuousKeys Vector of keys for continuous variables.
* @param discreteKeys Keys and cardinalities for discrete variables.
* @param nFrontals The number of frontal variables in the conditional.
*/
HybridConditional(const KeyVector& continuousKeys,
const DiscreteKeys& discreteKeys, size_t nFrontals)
: BaseFactor(continuousKeys, discreteKeys), BaseConditional(nFrontals) {}
/**
* @brief Construct a new Hybrid Conditional object
*
* @param continuousFrontals Vector of keys for continuous variables.
* @param discreteFrontals Keys and cardinalities for discrete variables.
* @param continuousParents Vector of keys for parent continuous variables.
* @param discreteParents Keys and cardinalities for parent discrete
* variables.
*/
HybridConditional(const KeyVector& continuousFrontals,
const DiscreteKeys& discreteFrontals,
const KeyVector& continuousParents,
const DiscreteKeys& discreteParents);
/**
* @brief Construct a new Hybrid Conditional object
*
* @param continuousConditional Conditional used to create the
* HybridConditional.
*/
HybridConditional(
boost::shared_ptr<GaussianConditional> continuousConditional);
/**
* @brief Construct a new Hybrid Conditional object
*
* @param discreteConditional Conditional used to create the
* HybridConditional.
*/
HybridConditional(boost::shared_ptr<DiscreteConditional> discreteConditional);
/**
* @brief Construct a new Hybrid Conditional object
*
* @param gaussianMixture Gaussian Mixture Conditional used to create the
* HybridConditional.
*/
HybridConditional(
boost::shared_ptr<GaussianMixtureConditional> gaussianMixture);
/**
* @brief Return HybridConditional as a GaussianMixtureConditional
*
* @return GaussianMixtureConditional::shared_ptr
*/
GaussianMixtureConditional::shared_ptr asMixture() {
if (!isHybrid_) throw std::invalid_argument("Not a mixture");
return boost::static_pointer_cast<GaussianMixtureConditional>(inner);
if (!isHybrid()) throw std::invalid_argument("Not a mixture");
return boost::static_pointer_cast<GaussianMixtureConditional>(inner_);
}
/**
* @brief Return conditional as a DiscreteConditional
*
* @return DiscreteConditional::shared_ptr
*/
DiscreteConditional::shared_ptr asDiscreteConditional() {
if (!isDiscrete_) throw std::invalid_argument("Not a discrete conditional");
return boost::static_pointer_cast<DiscreteConditional>(inner);
if (!isDiscrete()) throw std::invalid_argument("Not a discrete conditional");
return boost::static_pointer_cast<DiscreteConditional>(inner_);
}
boost::shared_ptr<Factor> getInner() { return inner; }
/// @}
/// @name Testable
/// @{
@ -122,11 +164,10 @@ class GTSAM_EXPORT HybridConditional
/// @}
friend std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> //
EliminateHybrid(const HybridFactorGraph& factors,
const Ordering& frontalKeys);
};
// DiscreteConditional
/// Get the type-erased pointer to the inner type
boost::shared_ptr<Factor> inner() { return inner_; }
}; // DiscreteConditional
// traits
template <>