Hybrid and Gaussian Mixture conditional docs and some refactor

release/4.3a0
Varun Agrawal 2022-05-27 15:12:19 -04:00
parent 3f239c28be
commit c3a92a4705
4 changed files with 116 additions and 43 deletions

View File

@ -42,7 +42,7 @@ GaussianMixtureConditional::conditionals() {
} }
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixtureConditional GaussianMixtureConditional::FromConditionalList( GaussianMixtureConditional GaussianMixtureConditional::FromConditionals(
const KeyVector &continuousFrontals, const KeyVector &continuousParents, const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents, const DiscreteKeys &discreteParents,
const std::vector<GaussianConditional::shared_ptr> &conditionalsList) { const std::vector<GaussianConditional::shared_ptr> &conditionalsList) {
@ -86,12 +86,12 @@ bool GaussianMixtureConditional::equals(const HybridFactor &lf,
void GaussianMixtureConditional::print(const std::string &s, void GaussianMixtureConditional::print(const std::string &s,
const KeyFormatter &formatter) const { const KeyFormatter &formatter) const {
std::cout << s << ": "; std::cout << s << ": ";
if (isContinuous_) std::cout << "Cont. "; if (isContinuous()) std::cout << "Cont. ";
if (isDiscrete_) std::cout << "Disc. "; if (isDiscrete()) std::cout << "Disc. ";
if (isHybrid_) std::cout << "Hybr. "; if (isHybrid()) std::cout << "Hybr. ";
BaseConditional::print("", formatter); BaseConditional::print("", formatter);
std::cout << "Discrete Keys = "; std::cout << "Discrete Keys = ";
for (auto &dk : discreteKeys_) { for (auto &dk : discreteKeys()) {
std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), "; std::cout << "(" << formatter(dk.first) << ", " << dk.second << "), ";
} }
std::cout << "\n"; std::cout << "\n";

View File

@ -25,6 +25,14 @@
#include <gtsam/linear/GaussianConditional.h> #include <gtsam/linear/GaussianConditional.h>
namespace gtsam { namespace gtsam {
/**
* @brief A conditional of gaussian mixtures indexed by discrete variables.
*
* Represents the conditional density P(X | M, Z) where X is a continuous random
* variable, M is the discrete variable and Z is the set of measurements.
*
*/
class GaussianMixtureConditional class GaussianMixtureConditional
: public HybridFactor, : public HybridFactor,
public Conditional<HybridFactor, GaussianMixtureConditional> { public Conditional<HybridFactor, GaussianMixtureConditional> {
@ -34,13 +42,28 @@ class GaussianMixtureConditional
using BaseFactor = HybridFactor; using BaseFactor = HybridFactor;
using BaseConditional = Conditional<HybridFactor, GaussianMixtureConditional>; using BaseConditional = Conditional<HybridFactor, GaussianMixtureConditional>;
/// Alias for DecisionTree of GaussianFactorGraphs
using Sum = DecisionTree<Key, GaussianFactorGraph>;
/// typedef for Decision Tree of Gaussian Conditionals
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>; using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
private:
Conditionals conditionals_; Conditionals conditionals_;
public:
/** /**
* @brief Construct a new Gaussian Mixture object * @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
*/
Sum asGaussianFactorGraphTree() const;
public:
/// @name Constructors
/// @{
/// Defaut constructor, mainly for serialization.
GaussianMixtureConditional() = default;
/**
* @brief Construct a new GaussianMixtureConditional object
* *
* @param continuousFrontals the continuous frontals. * @param continuousFrontals the continuous frontals.
* @param continuousParents the continuous parents. * @param continuousParents the continuous parents.
@ -52,15 +75,6 @@ class GaussianMixtureConditional
const DiscreteKeys &discreteParents, const DiscreteKeys &discreteParents,
const Conditionals &conditionals); const Conditionals &conditionals);
using Sum = DecisionTree<Key, GaussianFactorGraph>;
const Conditionals &conditionals();
/**
* @brief Combine Decision Trees
*/
Sum add(const Sum &sum) const;
/** /**
* @brief Make a Gaussian Mixture from a list of Gaussian conditionals * @brief Make a Gaussian Mixture from a list of Gaussian conditionals
* *
@ -69,11 +83,15 @@ class GaussianMixtureConditional
* @param discreteParents Discrete parents variables * @param discreteParents Discrete parents variables
* @param conditionals List of conditionals * @param conditionals List of conditionals
*/ */
static This FromConditionalList( static This FromConditionals(
const KeyVector &continuousFrontals, const KeyVector &continuousParents, const KeyVector &continuousFrontals, const KeyVector &continuousParents,
const DiscreteKeys &discreteParents, const DiscreteKeys &discreteParents,
const std::vector<GaussianConditional::shared_ptr> &conditionals); const std::vector<GaussianConditional::shared_ptr> &conditionals);
/// @}
/// @name Testable
/// @{
/// Test equality with base HybridFactor /// Test equality with base HybridFactor
bool equals(const HybridFactor &lf, double tol = 1e-9) const override; bool equals(const HybridFactor &lf, double tol = 1e-9) const override;
@ -82,11 +100,19 @@ class GaussianMixtureConditional
const std::string &s = "GaussianMixtureConditional\n", const std::string &s = "GaussianMixtureConditional\n",
const KeyFormatter &formatter = DefaultKeyFormatter) const override; const KeyFormatter &formatter = DefaultKeyFormatter) const override;
protected: /// @}
/// Getter for the underlying Conditionals DecisionTree
const Conditionals &conditionals();
/** /**
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs. * @brief Merge the Gaussian Factor Graphs in `this` and `sum` while
* maintaining the decision tree structure.
*
* @param sum Decision Tree of Gaussian Factor Graphs
* @return Sum
*/ */
Sum asGaussianFactorGraphTree() const; Sum add(const Sum &sum) const;
}; };
} // namespace gtsam } // namespace gtsam

View File

@ -22,6 +22,7 @@
namespace gtsam { namespace gtsam {
/* ************************************************************************ */
HybridConditional::HybridConditional(const KeyVector &continuousFrontals, HybridConditional::HybridConditional(const KeyVector &continuousFrontals,
const DiscreteKeys &discreteFrontals, const DiscreteKeys &discreteFrontals,
const KeyVector &continuousParents, const KeyVector &continuousParents,
@ -35,36 +36,40 @@ HybridConditional::HybridConditional(const KeyVector &continuousFrontals,
{discreteParents.begin(), discreteParents.end()}), {discreteParents.begin(), discreteParents.end()}),
continuousFrontals.size() + discreteFrontals.size()) {} continuousFrontals.size() + discreteFrontals.size()) {}
/* ************************************************************************ */
HybridConditional::HybridConditional( HybridConditional::HybridConditional(
boost::shared_ptr<GaussianConditional> continuousConditional) boost::shared_ptr<GaussianConditional> continuousConditional)
: HybridConditional(continuousConditional->keys(), {}, : HybridConditional(continuousConditional->keys(), {},
continuousConditional->nrFrontals()) { continuousConditional->nrFrontals()) {
inner = continuousConditional; inner_ = continuousConditional;
} }
/* ************************************************************************ */
HybridConditional::HybridConditional( HybridConditional::HybridConditional(
boost::shared_ptr<DiscreteConditional> discreteConditional) boost::shared_ptr<DiscreteConditional> discreteConditional)
: HybridConditional({}, discreteConditional->discreteKeys(), : HybridConditional({}, discreteConditional->discreteKeys(),
discreteConditional->nrFrontals()) { discreteConditional->nrFrontals()) {
inner = discreteConditional; inner_ = discreteConditional;
} }
/* ************************************************************************ */
HybridConditional::HybridConditional( HybridConditional::HybridConditional(
boost::shared_ptr<GaussianMixtureConditional> gaussianMixture) boost::shared_ptr<GaussianMixtureConditional> gaussianMixture)
: BaseFactor(KeyVector(gaussianMixture->keys().begin(), : BaseFactor(KeyVector(gaussianMixture->keys().begin(),
gaussianMixture->keys().begin() + gaussianMixture->keys().begin() +
gaussianMixture->nrContinuous), gaussianMixture->nrContinuous()),
gaussianMixture->discreteKeys_), gaussianMixture->discreteKeys()),
BaseConditional(gaussianMixture->nrFrontals()) { BaseConditional(gaussianMixture->nrFrontals()) {
inner = gaussianMixture; inner_ = gaussianMixture;
} }
/* ************************************************************************ */
void HybridConditional::print(const std::string &s, void HybridConditional::print(const std::string &s,
const KeyFormatter &formatter) const { const KeyFormatter &formatter) const {
std::cout << s; std::cout << s;
if (isContinuous_) std::cout << "Cont. "; if (isContinuous()) std::cout << "Cont. ";
if (isDiscrete_) std::cout << "Disc. "; if (isDiscrete()) std::cout << "Disc. ";
if (isHybrid_) std::cout << "Hybr. "; if (isHybrid()) std::cout << "Hybr. ";
std::cout << "P("; std::cout << "P(";
size_t index = 0; size_t index = 0;
const size_t N = keys().size(); const size_t N = keys().size();
@ -85,9 +90,10 @@ void HybridConditional::print(const std::string &s,
index++; index++;
} }
std::cout << ")\n"; std::cout << ")\n";
if (inner) inner->print("", formatter); if (inner_) inner_->print("", formatter);
} }
/* ************************************************************************ */
bool HybridConditional::equals(const HybridFactor &other, double tol) const { bool HybridConditional::equals(const HybridFactor &other, double tol) const {
return BaseFactor::equals(other, tol); return BaseFactor::equals(other, tol);
} }

View File

@ -54,8 +54,8 @@ class HybridFactorGraph;
* having diamond inheritances, and neutralized the need to change other * having diamond inheritances, and neutralized the need to change other
* components of GTSAM to make hybrid elimination work. * components of GTSAM to make hybrid elimination work.
* *
* A great reference to the type-erasure pattern is Edurado Madrid's CppCon * A great reference to the type-erasure pattern is Eduaado Madrid's CppCon
* talk. * talk (https://www.youtube.com/watch?v=s082Qmd_nHs).
*/ */
class GTSAM_EXPORT HybridConditional class GTSAM_EXPORT HybridConditional
: public HybridFactor, : public HybridFactor,
@ -70,7 +70,7 @@ class GTSAM_EXPORT HybridConditional
protected: protected:
// Type-erased pointer to the inner type // Type-erased pointer to the inner type
boost::shared_ptr<Factor> inner; boost::shared_ptr<Factor> inner_;
public: public:
/// @name Standard Constructors /// @name Standard Constructors
@ -79,35 +79,77 @@ class GTSAM_EXPORT HybridConditional
/// Default constructor needed for serialization. /// Default constructor needed for serialization.
HybridConditional() = default; HybridConditional() = default;
/**
* @brief Construct a new Hybrid Conditional object
*
* @param continuousKeys Vector of keys for continuous variables.
* @param discreteKeys Keys and cardinalities for discrete variables.
* @param nFrontals The number of frontal variables in the conditional.
*/
HybridConditional(const KeyVector& continuousKeys, HybridConditional(const KeyVector& continuousKeys,
const DiscreteKeys& discreteKeys, size_t nFrontals) const DiscreteKeys& discreteKeys, size_t nFrontals)
: BaseFactor(continuousKeys, discreteKeys), BaseConditional(nFrontals) {} : BaseFactor(continuousKeys, discreteKeys), BaseConditional(nFrontals) {}
/**
* @brief Construct a new Hybrid Conditional object
*
* @param continuousFrontals Vector of keys for continuous variables.
* @param discreteFrontals Keys and cardinalities for discrete variables.
* @param continuousParents Vector of keys for parent continuous variables.
* @param discreteParents Keys and cardinalities for parent discrete
* variables.
*/
HybridConditional(const KeyVector& continuousFrontals, HybridConditional(const KeyVector& continuousFrontals,
const DiscreteKeys& discreteFrontals, const DiscreteKeys& discreteFrontals,
const KeyVector& continuousParents, const KeyVector& continuousParents,
const DiscreteKeys& discreteParents); const DiscreteKeys& discreteParents);
/**
* @brief Construct a new Hybrid Conditional object
*
* @param continuousConditional Conditional used to create the
* HybridConditional.
*/
HybridConditional( HybridConditional(
boost::shared_ptr<GaussianConditional> continuousConditional); boost::shared_ptr<GaussianConditional> continuousConditional);
/**
* @brief Construct a new Hybrid Conditional object
*
* @param discreteConditional Conditional used to create the
* HybridConditional.
*/
HybridConditional(boost::shared_ptr<DiscreteConditional> discreteConditional); HybridConditional(boost::shared_ptr<DiscreteConditional> discreteConditional);
/**
* @brief Construct a new Hybrid Conditional object
*
* @param gaussianMixture Gaussian Mixture Conditional used to create the
* HybridConditional.
*/
HybridConditional( HybridConditional(
boost::shared_ptr<GaussianMixtureConditional> gaussianMixture); boost::shared_ptr<GaussianMixtureConditional> gaussianMixture);
/**
* @brief Return HybridConditional as a GaussianMixtureConditional
*
* @return GaussianMixtureConditional::shared_ptr
*/
GaussianMixtureConditional::shared_ptr asMixture() { GaussianMixtureConditional::shared_ptr asMixture() {
if (!isHybrid_) throw std::invalid_argument("Not a mixture"); if (!isHybrid()) throw std::invalid_argument("Not a mixture");
return boost::static_pointer_cast<GaussianMixtureConditional>(inner); return boost::static_pointer_cast<GaussianMixtureConditional>(inner_);
} }
/**
* @brief Return conditional as a DiscreteConditional
*
* @return DiscreteConditional::shared_ptr
*/
DiscreteConditional::shared_ptr asDiscreteConditional() { DiscreteConditional::shared_ptr asDiscreteConditional() {
if (!isDiscrete_) throw std::invalid_argument("Not a discrete conditional"); if (!isDiscrete()) throw std::invalid_argument("Not a discrete conditional");
return boost::static_pointer_cast<DiscreteConditional>(inner); return boost::static_pointer_cast<DiscreteConditional>(inner_);
} }
boost::shared_ptr<Factor> getInner() { return inner; }
/// @} /// @}
/// @name Testable /// @name Testable
/// @{ /// @{
@ -122,11 +164,10 @@ class GTSAM_EXPORT HybridConditional
/// @} /// @}
friend std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> // /// Get the type-erased pointer to the inner type
EliminateHybrid(const HybridFactorGraph& factors, boost::shared_ptr<Factor> inner() { return inner_; }
const Ordering& frontalKeys);
}; }; // DiscreteConditional
// DiscreteConditional
// traits // traits
template <> template <>