rename GaussianMixture to HybridGaussianConditional

release/4.3a0
Varun Agrawal 2024-09-13 05:41:24 -04:00
parent 187935407c
commit aef273bce8
27 changed files with 161 additions and 160 deletions

@@ -191,13 +191,13 @@ E_{gc}(x,y)=\frac{1}{2}\|Rx+Sy-d\|_{\Sigma}^{2}.\label{eq:gc_error}
 \end_layout
 \begin_layout Subsubsection*
-GaussianMixture
+HybridGaussianConditional
 \end_layout
 \begin_layout Standard
 A
 \emph on
-GaussianMixture
+HybridGaussianConditional
 \emph default
 (maybe to be renamed to
 \emph on
@@ -233,7 +233,7 @@ GaussianConditional
 to a set of discrete variables.
 As
 \emph on
-GaussianMixture
+HybridGaussianConditional
 \emph default
 is a
 \emph on
@@ -324,7 +324,7 @@ The key point here is that
 \color inherit
 is the log-normalization constant for the complete
 \emph on
-GaussianMixture
+HybridGaussianConditional
 \emph default
 across all values of
 \begin_inset Formula $m$
@@ -556,7 +556,7 @@ Analogously, a
 \emph on
 HybridGaussianFactor
 \emph default
-typically results from a GaussianMixture by having known values
+typically results from a HybridGaussianConditional by having known values
 \begin_inset Formula $\bar{x}$
 \end_inset
@@ -817,7 +817,7 @@ E_{mf}(y,m)=\frac{1}{2}\|A_{m}y-b_{m}\|_{\Sigma_{mfm}}^{2}=E_{gcm}(\bar{x},y)+K_
 \end_inset
-which is identical to the GaussianMixture error
+which is identical to the HybridGaussianConditional error
 \begin_inset CommandInset ref
 LatexCommand eqref
 reference "eq:gm_error"

@@ -10,7 +10,7 @@
  * -------------------------------------------------------------------------- */
 /**
- * @file GaussianMixture.cpp
+ * @file HybridGaussianConditional.cpp
  * @brief A hybrid conditional in the Conditional Linear Gaussian scheme
  * @author Fan Jiang
  * @author Varun Agrawal
@@ -20,7 +20,7 @@
 #include <gtsam/base/utilities.h>
 #include <gtsam/discrete/DiscreteValues.h>
-#include <gtsam/hybrid/GaussianMixture.h>
+#include <gtsam/hybrid/HybridGaussianConditional.h>
 #include <gtsam/hybrid/HybridGaussianFactor.h>
 #include <gtsam/hybrid/HybridValues.h>
 #include <gtsam/inference/Conditional-inst.h>
@@ -29,10 +29,10 @@
 namespace gtsam {
-GaussianMixture::GaussianMixture(
+HybridGaussianConditional::HybridGaussianConditional(
 const KeyVector &continuousFrontals, const KeyVector &continuousParents,
 const DiscreteKeys &discreteParents,
-const GaussianMixture::Conditionals &conditionals)
+const HybridGaussianConditional::Conditionals &conditionals)
 : BaseFactor(CollectKeys(continuousFrontals, continuousParents),
 discreteParents),
 BaseConditional(continuousFrontals.size()),
@@ -50,30 +50,30 @@ GaussianMixture::GaussianMixture(
 }
 /* *******************************************************************************/
-const GaussianMixture::Conditionals &GaussianMixture::conditionals() const {
+const HybridGaussianConditional::Conditionals &HybridGaussianConditional::conditionals() const {
 return conditionals_;
 }
 /* *******************************************************************************/
-GaussianMixture::GaussianMixture(
+HybridGaussianConditional::HybridGaussianConditional(
 KeyVector &&continuousFrontals, KeyVector &&continuousParents,
 DiscreteKeys &&discreteParents,
 std::vector<GaussianConditional::shared_ptr> &&conditionals)
-: GaussianMixture(continuousFrontals, continuousParents, discreteParents,
+: HybridGaussianConditional(continuousFrontals, continuousParents, discreteParents,
 Conditionals(discreteParents, conditionals)) {}
 /* *******************************************************************************/
-GaussianMixture::GaussianMixture(
+HybridGaussianConditional::HybridGaussianConditional(
 const KeyVector &continuousFrontals, const KeyVector &continuousParents,
 const DiscreteKeys &discreteParents,
 const std::vector<GaussianConditional::shared_ptr> &conditionals)
-: GaussianMixture(continuousFrontals, continuousParents, discreteParents,
+: HybridGaussianConditional(continuousFrontals, continuousParents, discreteParents,
 Conditionals(discreteParents, conditionals)) {}
 /* *******************************************************************************/
-// TODO(dellaert): This is copy/paste: GaussianMixture should be derived from
+// TODO(dellaert): This is copy/paste: HybridGaussianConditional should be derived from
 // GaussianMixtureFactor, no?
-GaussianFactorGraphTree GaussianMixture::add(
+GaussianFactorGraphTree HybridGaussianConditional::add(
 const GaussianFactorGraphTree &sum) const {
 using Y = GaussianFactorGraph;
 auto add = [](const Y &graph1, const Y &graph2) {
@@ -86,7 +86,7 @@ GaussianFactorGraphTree GaussianMixture::add(
 }
 /* *******************************************************************************/
-GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
+GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree() const {
 auto wrap = [this](const GaussianConditional::shared_ptr &gc) {
 // First check if conditional has not been pruned
 if (gc) {
@@ -109,7 +109,7 @@ GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
 }
 /* *******************************************************************************/
-size_t GaussianMixture::nrComponents() const {
+size_t HybridGaussianConditional::nrComponents() const {
 size_t total = 0;
 conditionals_.visit([&total](const GaussianFactor::shared_ptr &node) {
 if (node) total += 1;
@@ -118,7 +118,7 @@ size_t GaussianMixture::nrComponents() const {
 }
 /* *******************************************************************************/
-GaussianConditional::shared_ptr GaussianMixture::operator()(
+GaussianConditional::shared_ptr HybridGaussianConditional::operator()(
 const DiscreteValues &discreteValues) const {
 auto &ptr = conditionals_(discreteValues);
 if (!ptr) return nullptr;
@@ -127,11 +127,11 @@ GaussianConditional::shared_ptr GaussianMixture::operator()(
 return conditional;
 else
 throw std::logic_error(
-"A GaussianMixture unexpectedly contained a non-conditional");
+"A HybridGaussianConditional unexpectedly contained a non-conditional");
 }
 /* *******************************************************************************/
-bool GaussianMixture::equals(const HybridFactor &lf, double tol) const {
+bool HybridGaussianConditional::equals(const HybridFactor &lf, double tol) const {
 const This *e = dynamic_cast<const This *>(&lf);
 if (e == nullptr) return false;
@@ -149,7 +149,7 @@ bool GaussianMixture::equals(const HybridFactor &lf, double tol) const {
 }
 /* *******************************************************************************/
-void GaussianMixture::print(const std::string &s,
+void HybridGaussianConditional::print(const std::string &s,
 const KeyFormatter &formatter) const {
 std::cout << (s.empty() ? "" : s + "\n");
 if (isContinuous()) std::cout << "Continuous ";
@@ -177,7 +177,7 @@ void GaussianMixture::print(const std::string &s,
 }
 /* ************************************************************************* */
-KeyVector GaussianMixture::continuousParents() const {
+KeyVector HybridGaussianConditional::continuousParents() const {
 // Get all parent keys:
 const auto range = parents();
 KeyVector continuousParentKeys(range.begin(), range.end());
@@ -193,7 +193,7 @@ KeyVector GaussianMixture::continuousParents() const {
 }
 /* ************************************************************************* */
-bool GaussianMixture::allFrontalsGiven(const VectorValues &given) const {
+bool HybridGaussianConditional::allFrontalsGiven(const VectorValues &given) const {
 for (auto &&kv : given) {
 if (given.find(kv.first) == given.end()) {
 return false;
@@ -203,11 +203,11 @@ bool GaussianMixture::allFrontalsGiven(const VectorValues &given) const {
 }
 /* ************************************************************************* */
-std::shared_ptr<GaussianMixtureFactor> GaussianMixture::likelihood(
+std::shared_ptr<GaussianMixtureFactor> HybridGaussianConditional::likelihood(
 const VectorValues &given) const {
 if (!allFrontalsGiven(given)) {
 throw std::runtime_error(
-"GaussianMixture::likelihood: given values are missing some frontals.");
+"HybridGaussianConditional::likelihood: given values are missing some frontals.");
 }
 const DiscreteKeys discreteParentKeys = discreteKeys();
@@ -252,7 +252,7 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys) {
 */
 std::function<GaussianConditional::shared_ptr(
 const Assignment<Key> &, const GaussianConditional::shared_ptr &)>
-GaussianMixture::prunerFunc(const DecisionTreeFactor &discreteProbs) {
+HybridGaussianConditional::prunerFunc(const DecisionTreeFactor &discreteProbs) {
 // Get the discrete keys as sets for the decision tree
 // and the gaussian mixture.
 auto discreteProbsKeySet = DiscreteKeysAsSet(discreteProbs.discreteKeys());
@@ -303,7 +303,7 @@ GaussianMixture::prunerFunc(const DecisionTreeFactor &discreteProbs) {
 }
 /* *******************************************************************************/
-void GaussianMixture::prune(const DecisionTreeFactor &discreteProbs) {
+void HybridGaussianConditional::prune(const DecisionTreeFactor &discreteProbs) {
 // Functional which loops over all assignments and create a set of
 // GaussianConditionals
 auto pruner = prunerFunc(discreteProbs);
@@ -313,7 +313,7 @@ void GaussianMixture::prune(const DecisionTreeFactor &discreteProbs) {
 }
 /* *******************************************************************************/
-AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
+AlgebraicDecisionTree<Key> HybridGaussianConditional::logProbability(
 const VectorValues &continuousValues) const {
 // functor to calculate (double) logProbability value from
 // GaussianConditional.
@@ -331,7 +331,7 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
 }
 /* ************************************************************************* */
-double GaussianMixture::conditionalError(
+double HybridGaussianConditional::conditionalError(
 const GaussianConditional::shared_ptr &conditional,
 const VectorValues &continuousValues) const {
 // Check if valid pointer
@@ -348,7 +348,7 @@ double GaussianMixture::conditionalError(
 }
 /* *******************************************************************************/
-AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
+AlgebraicDecisionTree<Key> HybridGaussianConditional::errorTree(
 const VectorValues &continuousValues) const {
 auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
 return conditionalError(conditional, continuousValues);
@@ -358,20 +358,20 @@ AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
 }
 /* *******************************************************************************/
-double GaussianMixture::error(const HybridValues &values) const {
+double HybridGaussianConditional::error(const HybridValues &values) const {
 // Directly index to get the conditional, no need to build the whole tree.
 auto conditional = conditionals_(values.discrete());
 return conditionalError(conditional, values.continuous());
 }
 /* *******************************************************************************/
-double GaussianMixture::logProbability(const HybridValues &values) const {
+double HybridGaussianConditional::logProbability(const HybridValues &values) const {
 auto conditional = conditionals_(values.discrete());
 return conditional->logProbability(values.continuous());
 }
 /* *******************************************************************************/
-double GaussianMixture::evaluate(const HybridValues &values) const {
+double HybridGaussianConditional::evaluate(const HybridValues &values) const {
 auto conditional = conditionals_(values.discrete());
 return conditional->evaluate(values.continuous());
 }
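As a point of reference, a minimal usage sketch (not part of this commit) of the renamed class, assuming only the constructors and the error/evaluate/operator() methods shown in the file above; the keys, sigmas, and numeric values are illustrative:

    #include <gtsam/hybrid/HybridGaussianConditional.h>
    #include <gtsam/hybrid/HybridValues.h>
    #include <gtsam/inference/Symbol.h>

    using namespace gtsam;
    using symbol_shorthand::M;
    using symbol_shorthand::X;
    using symbol_shorthand::Z;

    void exampleEvaluate() {
      // One GaussianConditional per value of the binary mode m0: P(z0 | x0, m0).
      const DiscreteKey mode(M(0), 2);
      const std::vector<GaussianConditional::shared_ptr> components{
          GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Vector1(0.0), 0.5),
          GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Vector1(0.0), 3.0)};
      const HybridGaussianConditional hgc({Z(0)}, {X(0)}, {mode}, components);

      // Evaluate for a complete hybrid assignment (continuous + discrete).
      VectorValues vv;
      vv.insert(X(0), Vector1(1.0));
      vv.insert(Z(0), Vector1(0.9));
      const HybridValues hv(vv, DiscreteValues{{M(0), 1}});
      const double error = hgc.error(hv);       // uses the m0=1 component directly
      const double density = hgc.evaluate(hv);  // normalized density value
      // Indexing by the discrete assignment alone returns the selected component.
      GaussianConditional::shared_ptr chosen = hgc(DiscreteValues{{M(0), 1}});
    }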

@@ -10,7 +10,7 @@
  * -------------------------------------------------------------------------- */
 /**
- * @file GaussianMixture.h
+ * @file HybridGaussianConditional.h
  * @brief A hybrid conditional in the Conditional Linear Gaussian scheme
  * @author Fan Jiang
  * @author Varun Agrawal
@@ -50,14 +50,14 @@ class HybridValues;
 *
 * @ingroup hybrid
 */
-class GTSAM_EXPORT GaussianMixture
+class GTSAM_EXPORT HybridGaussianConditional
 : public HybridFactor,
-public Conditional<HybridFactor, GaussianMixture> {
+public Conditional<HybridFactor, HybridGaussianConditional> {
 public:
-using This = GaussianMixture;
+using This = HybridGaussianConditional;
-using shared_ptr = std::shared_ptr<GaussianMixture>;
+using shared_ptr = std::shared_ptr<HybridGaussianConditional>;
 using BaseFactor = HybridFactor;
-using BaseConditional = Conditional<HybridFactor, GaussianMixture>;
+using BaseConditional = Conditional<HybridFactor, HybridGaussianConditional>;
 /// typedef for Decision Tree of Gaussian Conditionals
 using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
@@ -67,7 +67,7 @@ class GTSAM_EXPORT GaussianMixture
 double logConstant_; ///< log of the normalization constant.
 /**
-* @brief Convert a GaussianMixture of conditionals into
+* @brief Convert a HybridGaussianConditional of conditionals into
 * a DecisionTree of Gaussian factor graphs.
 */
 GaussianFactorGraphTree asGaussianFactorGraphTree() const;
@@ -88,10 +88,10 @@ class GTSAM_EXPORT GaussianMixture
 /// @{
 /// Default constructor, mainly for serialization.
-GaussianMixture() = default;
+HybridGaussianConditional() = default;
 /**
-* @brief Construct a new GaussianMixture object.
+* @brief Construct a new HybridGaussianConditional object.
 *
 * @param continuousFrontals the continuous frontals.
 * @param continuousParents the continuous parents.
@@ -101,7 +101,7 @@ class GTSAM_EXPORT GaussianMixture
 * cardinality of the DiscreteKeys in discreteParents, since the
 * discreteParents will be used as the labels in the decision tree.
 */
-GaussianMixture(const KeyVector &continuousFrontals,
+HybridGaussianConditional(const KeyVector &continuousFrontals,
 const KeyVector &continuousParents,
 const DiscreteKeys &discreteParents,
 const Conditionals &conditionals);
@@ -114,7 +114,7 @@ class GTSAM_EXPORT GaussianMixture
 * @param discreteParents Discrete parents variables
 * @param conditionals List of conditionals
 */
-GaussianMixture(KeyVector &&continuousFrontals, KeyVector &&continuousParents,
+HybridGaussianConditional(KeyVector &&continuousFrontals, KeyVector &&continuousParents,
 DiscreteKeys &&discreteParents,
 std::vector<GaussianConditional::shared_ptr> &&conditionals);
@@ -126,7 +126,7 @@ class GTSAM_EXPORT GaussianMixture
 * @param discreteParents Discrete parents variables
 * @param conditionals List of conditionals
 */
-GaussianMixture(
+HybridGaussianConditional(
 const KeyVector &continuousFrontals, const KeyVector &continuousParents,
 const DiscreteKeys &discreteParents,
 const std::vector<GaussianConditional::shared_ptr> &conditionals);
@@ -140,7 +140,7 @@ class GTSAM_EXPORT GaussianMixture
 /// Print utility
 void print(
-const std::string &s = "GaussianMixture\n",
+const std::string &s = "HybridGaussianConditional\n",
 const KeyFormatter &formatter = DefaultKeyFormatter) const override;
 /// @}
@@ -172,7 +172,7 @@ class GTSAM_EXPORT GaussianMixture
 const Conditionals &conditionals() const;
 /**
-* @brief Compute logProbability of the GaussianMixture as a tree.
+* @brief Compute logProbability of the HybridGaussianConditional as a tree.
 *
 * @param continuousValues The continuous VectorValues.
 * @return AlgebraicDecisionTree<Key> A decision tree with the same keys
@@ -209,7 +209,7 @@ class GTSAM_EXPORT GaussianMixture
 double error(const HybridValues &values) const override;
 /**
-* @brief Compute error of the GaussianMixture as a tree.
+* @brief Compute error of the HybridGaussianConditional as a tree.
 *
 * @param continuousValues The continuous VectorValues.
 * @return AlgebraicDecisionTree<Key> A decision tree on the discrete keys
@@ -277,6 +277,6 @@ std::set<DiscreteKey> DiscreteKeysAsSet(const DiscreteKeys &discreteKeys);
 // traits
 template <>
-struct traits<GaussianMixture> : public Testable<GaussianMixture> {};
+struct traits<HybridGaussianConditional> : public Testable<HybridGaussianConditional> {};
 } // namespace gtsam
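The tree-valued queries declared above can be used as in this hedged sketch (not part of the commit), assuming `hgc` is a HybridGaussianConditional built like the one in the unit tests further down in this diff:

    // One value per discrete assignment, organized as a decision tree over the
    // discrete parent keys.
    VectorValues vv;
    vv.insert(X(0), Vector1(1.0));
    vv.insert(Z(0), Vector1(0.9));
    const AlgebraicDecisionTree<Key> errors = hgc.errorTree(vv);
    const AlgebraicDecisionTree<Key> logProbs = hgc.logProbability(vv);
    errors.print("per-mode error", DefaultKeyFormatter);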

@@ -168,11 +168,11 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
 DecisionTreeFactor prunedDiscreteProbs =
 this->pruneDiscreteConditionals(maxNrLeaves);
-/* To prune, we visitWith every leaf in the GaussianMixture.
+/* To prune, we visitWith every leaf in the HybridGaussianConditional.
 * For each leaf, using the assignment we can check the discrete decision tree
 * for 0.0 probability, then just set the leaf to a nullptr.
 *
-* We can later check the GaussianMixture for just nullptrs.
+* We can later check the HybridGaussianConditional for just nullptrs.
 */
 HybridBayesNet prunedBayesNetFragment;
@@ -182,14 +182,14 @@ HybridBayesNet HybridBayesNet::prune(size_t maxNrLeaves) {
 for (auto &&conditional : *this) {
 if (auto gm = conditional->asMixture()) {
 // Make a copy of the Gaussian mixture and prune it!
-auto prunedGaussianMixture = std::make_shared<GaussianMixture>(*gm);
+auto prunedGaussianMixture = std::make_shared<HybridGaussianConditional>(*gm);
 prunedGaussianMixture->prune(prunedDiscreteProbs); // imperative :-(
 // Type-erase and add to the pruned Bayes Net fragment.
 prunedBayesNetFragment.push_back(prunedGaussianMixture);
 } else {
-// Add the non-GaussianMixture conditional
+// Add the non-HybridGaussianConditional conditional
 prunedBayesNetFragment.push_back(conditional);
 }
 }
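A hedged sketch of how the pruning path above is typically driven; `posterior` and the leaf count are illustrative, not from the commit:

    // posterior: a HybridBayesNet obtained from hybrid elimination.
    // Keep at most 4 leaves in the discrete part; each HybridGaussianConditional
    // in the result gets nullptr leaves for assignments pruned to zero probability.
    const HybridBayesNet pruned = posterior.prune(4);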

@@ -79,7 +79,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
 *
 * Example:
 * auto shared_ptr_to_a_conditional =
-* std::make_shared<GaussianMixture>(...);
+* std::make_shared<HybridGaussianConditional>(...);
 * hbn.push_back(shared_ptr_to_a_conditional);
 */
 void push_back(HybridConditional &&conditional) {
@@ -106,7 +106,7 @@ class GTSAM_EXPORT HybridBayesNet : public BayesNet<HybridConditional> {
 * Preferred: Emplace a conditional directly using arguments.
 *
 * Examples:
-* hbn.emplace_shared<GaussianMixture>(...)));
+* hbn.emplace_shared<HybridGaussianConditional>(...)));
 * hbn.emplace_shared<GaussianConditional>(...)));
 * hbn.emplace_shared<DiscreteConditional>(...)));
 */
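A small sketch mirroring the doc-comment examples above with the post-rename name; the keys and noise values are illustrative only:

    HybridBayesNet hbn;
    const DiscreteKey mode(M(0), 2);
    // P(z0 | x0, m0): one Gaussian component per mode.
    hbn.emplace_shared<HybridGaussianConditional>(
        KeyVector{Z(0)}, KeyVector{X(0)}, DiscreteKeys{mode},
        std::vector{GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1, 0.5),
                    GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1, 3.0)});
    // P(m0): discrete prior on the mode.
    hbn.emplace_shared<DiscreteConditional>(mode, "0.5/0.5");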

@@ -55,7 +55,7 @@ HybridConditional::HybridConditional(
 /* ************************************************************************ */
 HybridConditional::HybridConditional(
-const std::shared_ptr<GaussianMixture> &gaussianMixture)
+const std::shared_ptr<HybridGaussianConditional> &gaussianMixture)
 : BaseFactor(KeyVector(gaussianMixture->keys().begin(),
 gaussianMixture->keys().begin() +
 gaussianMixture->nrContinuous()),

@@ -18,7 +18,7 @@
 #pragma once
 #include <gtsam/discrete/DiscreteConditional.h>
-#include <gtsam/hybrid/GaussianMixture.h>
+#include <gtsam/hybrid/HybridGaussianConditional.h>
 #include <gtsam/hybrid/HybridFactor.h>
 #include <gtsam/hybrid/HybridGaussianFactorGraph.h>
 #include <gtsam/inference/Conditional.h>
@@ -39,7 +39,7 @@ namespace gtsam {
 * As a type-erased variant of:
 * - DiscreteConditional
 * - GaussianConditional
-* - GaussianMixture
+* - HybridGaussianConditional
 *
 * The reason why this is important is that `Conditional<T>` is a CRTP class.
 * CRTP is static polymorphism such that all CRTP classes, while bearing the
@@ -127,7 +127,7 @@ class GTSAM_EXPORT HybridConditional
 * @param gaussianMixture Gaussian Mixture Conditional used to create the
 * HybridConditional.
 */
-HybridConditional(const std::shared_ptr<GaussianMixture>& gaussianMixture);
+HybridConditional(const std::shared_ptr<HybridGaussianConditional>& gaussianMixture);
 /// @}
 /// @name Testable
@@ -146,12 +146,12 @@ class GTSAM_EXPORT HybridConditional
 /// @{
 /**
-* @brief Return HybridConditional as a GaussianMixture
+* @brief Return HybridConditional as a HybridGaussianConditional
 * @return nullptr if not a mixture
-* @return GaussianMixture::shared_ptr otherwise
+* @return HybridGaussianConditional::shared_ptr otherwise
 */
-GaussianMixture::shared_ptr asMixture() const {
+HybridGaussianConditional::shared_ptr asMixture() const {
-return std::dynamic_pointer_cast<GaussianMixture>(inner_);
+return std::dynamic_pointer_cast<HybridGaussianConditional>(inner_);
 }
 /**
@@ -222,8 +222,8 @@ class GTSAM_EXPORT HybridConditional
 boost::serialization::void_cast_register<GaussianConditional, Factor>(
 static_cast<GaussianConditional*>(NULL), static_cast<Factor*>(NULL));
 } else {
-boost::serialization::void_cast_register<GaussianMixture, Factor>(
+boost::serialization::void_cast_register<HybridGaussianConditional, Factor>(
-static_cast<GaussianMixture*>(NULL), static_cast<Factor*>(NULL));
+static_cast<HybridGaussianConditional*>(NULL), static_cast<Factor*>(NULL));
 }
 }
 #endif
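A brief sketch of the type-erasure round trip implied by the constructor and asMixture() above; the helper function is hypothetical:

    // Wrap a HybridGaussianConditional in a type-erased HybridConditional and
    // recover it again; asMixture() returns nullptr if the inner type differs.
    void wrapAndRecover(const HybridGaussianConditional::shared_ptr& hgc) {
      const HybridConditional hc(hgc);
      if (auto mixture = hc.asMixture()) {
        mixture->print("recovered HybridGaussianConditional");
      }
    }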

@@ -47,7 +47,7 @@ DiscreteKeys CollectDiscreteKeys(const DiscreteKeys &key1,
 * Examples:
 * - HybridNonlinearFactor
 * - HybridGaussianFactor
-* - GaussianMixture
+* - HybridGaussianConditional
 *
 * @ingroup hybrid
 */

@@ -23,7 +23,7 @@
 #include <gtsam/discrete/DiscreteEliminationTree.h>
 #include <gtsam/discrete/DiscreteFactorGraph.h>
 #include <gtsam/discrete/DiscreteJunctionTree.h>
-#include <gtsam/hybrid/GaussianMixture.h>
+#include <gtsam/hybrid/HybridGaussianConditional.h>
 #include <gtsam/hybrid/HybridGaussianFactor.h>
 #include <gtsam/hybrid/HybridConditional.h>
 #include <gtsam/hybrid/HybridEliminationTree.h>
@@ -180,7 +180,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
 result = addGaussian(result, gf);
 } else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
 result = gmf->add(result);
-} else if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
+} else if (auto gm = dynamic_pointer_cast<HybridGaussianConditional>(f)) {
 result = gm->add(result);
 } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
 if (auto gm = hc->asMixture()) {
@@ -408,10 +408,10 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
 : createGaussianMixtureFactor(eliminationResults, continuousSeparator,
 discreteSeparator);
-// Create the GaussianMixture from the conditionals
+// Create the HybridGaussianConditional from the conditionals
-GaussianMixture::Conditionals conditionals(
+HybridGaussianConditional::Conditionals conditionals(
 eliminationResults, [](const Result &pair) { return pair.first; });
-auto gaussianMixture = std::make_shared<GaussianMixture>(
+auto gaussianMixture = std::make_shared<HybridGaussianConditional>(
 frontalKeys, continuousSeparator, discreteSeparator, conditionals);
 return {std::make_shared<HybridConditional>(gaussianMixture), newFactor};
@@ -458,7 +458,7 @@ EliminateHybrid(const HybridGaussianFactorGraph &factors,
 // Because of all these reasons, we carefully consider how to
 // implement the hybrid factors so that we do not get poor performance.
-// The first thing is how to represent the GaussianMixture.
+// The first thing is how to represent the HybridGaussianConditional.
 // A very possible scenario is that the incoming factors will have different
 // levels of discrete keys. For example, imagine we are going to eliminate the
 // fragment: $\phi(x1,c1,c2)$, $\phi(x1,c2,c3)$, which is perfectly valid.
@@ -599,7 +599,7 @@ GaussianFactorGraph HybridGaussianFactorGraph::operator()(
 gfg.push_back(gf);
 } else if (auto gmf = std::dynamic_pointer_cast<HybridGaussianFactor>(f)) {
 gfg.push_back((*gmf)(assignment));
-} else if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
+} else if (auto gm = dynamic_pointer_cast<HybridGaussianConditional>(f)) {
 gfg.push_back((*gm)(assignment));
 } else {
 continue;
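A hedged sketch of what the operator() hunk above enables: fixing a discrete assignment collapses the hybrid graph to an ordinary GaussianFactorGraph; the function and variable names are illustrative:

    void solveForAssignment(const HybridGaussianFactorGraph& hfg,
                            const DiscreteValues& assignment) {
      // Each HybridGaussianConditional / HybridGaussianFactor contributes the
      // component selected by `assignment`; plain Gaussian factors pass through.
      const GaussianFactorGraph gfg = hfg(assignment);
      const VectorValues solution = gfg.optimize();
      solution.print("continuous solution for this assignment");
    }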

@@ -18,7 +18,7 @@
 #include <gtsam/discrete/DecisionTreeFactor.h>
 #include <gtsam/discrete/TableFactor.h>
-#include <gtsam/hybrid/GaussianMixture.h>
+#include <gtsam/hybrid/HybridGaussianConditional.h>
 #include <gtsam/hybrid/HybridGaussianFactorGraph.h>
 #include <gtsam/hybrid/HybridNonlinearFactorGraph.h>
 #include <gtsam/hybrid/HybridNonlinearFactor.h>
@@ -80,7 +80,7 @@ void HybridNonlinearFactorGraph::printErrors(
 gmf->errorTree(values.continuous()).print("", keyFormatter);
 std::cout << std::endl;
 }
-} else if (auto gm = std::dynamic_pointer_cast<GaussianMixture>(factor)) {
+} else if (auto gm = std::dynamic_pointer_cast<HybridGaussianConditional>(factor)) {
 if (factor == nullptr) {
 std::cout << "nullptr"
 << "\n";
@@ -163,7 +163,7 @@ HybridGaussianFactorGraph::shared_ptr HybridNonlinearFactorGraph::linearize(
 linearFG->push_back(f);
 } else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
 linearFG->push_back(gmf);
-} else if (auto gm = dynamic_pointer_cast<GaussianMixture>(f)) {
+} else if (auto gm = dynamic_pointer_cast<HybridGaussianConditional>(f)) {
 linearFG->push_back(gm);
 } else if (dynamic_pointer_cast<GaussianFactor>(f)) {
 linearFG->push_back(f);

@@ -138,7 +138,7 @@ HybridSmoother::addConditionals(const HybridGaussianFactorGraph &originalGraph,
 }
 /* ************************************************************************* */
-GaussianMixture::shared_ptr HybridSmoother::gaussianMixture(
+HybridGaussianConditional::shared_ptr HybridSmoother::gaussianMixture(
 size_t index) const {
 return hybridBayesNet_.at(index)->asMixture();
 }

@@ -69,7 +69,7 @@ class GTSAM_EXPORT HybridSmoother {
 const HybridBayesNet& hybridBayesNet, const Ordering& ordering) const;
 /// Get the Gaussian Mixture from the Bayes Net posterior at `index`.
-GaussianMixture::shared_ptr gaussianMixture(size_t index) const;
+HybridGaussianConditional::shared_ptr gaussianMixture(size_t index) const;
 /// Return the Bayes Net posterior.
 const HybridBayesNet& hybridBayesNet() const;

@@ -65,7 +65,7 @@ virtual class HybridConditional {
 double logProbability(const gtsam::HybridValues& values) const;
 double evaluate(const gtsam::HybridValues& values) const;
 double operator()(const gtsam::HybridValues& values) const;
-gtsam::GaussianMixture* asMixture() const;
+gtsam::HybridGaussianConditional* asMixture() const;
 gtsam::GaussianConditional* asGaussian() const;
 gtsam::DiscreteConditional* asDiscrete() const;
 gtsam::Factor* inner();
@@ -84,9 +84,9 @@ class HybridGaussianFactor : gtsam::HybridFactor {
 gtsam::DefaultKeyFormatter) const;
 };
-#include <gtsam/hybrid/GaussianMixture.h>
+#include <gtsam/hybrid/HybridGaussianConditional.h>
-class GaussianMixture : gtsam::HybridFactor {
+class HybridGaussianConditional : gtsam::HybridFactor {
-GaussianMixture(const gtsam::KeyVector& continuousFrontals,
+HybridGaussianConditional(const gtsam::KeyVector& continuousFrontals,
 const gtsam::KeyVector& continuousParents,
 const gtsam::DiscreteKeys& discreteParents,
 const std::vector<gtsam::GaussianConditional::shared_ptr>&
@@ -97,7 +97,7 @@ class GaussianMixture : gtsam::HybridFactor {
 double logProbability(const gtsam::HybridValues& values) const;
 double evaluate(const gtsam::HybridValues& values) const;
-void print(string s = "GaussianMixture\n",
+void print(string s = "HybridGaussianConditional\n",
 const gtsam::KeyFormatter& keyFormatter =
 gtsam::DefaultKeyFormatter) const;
 };
@@ -131,7 +131,7 @@ class HybridBayesTree {
 #include <gtsam/hybrid/HybridBayesNet.h>
 class HybridBayesNet {
 HybridBayesNet();
-void push_back(const gtsam::GaussianMixture* s);
+void push_back(const gtsam::HybridGaussianConditional* s);
 void push_back(const gtsam::GaussianConditional* s);
 void push_back(const gtsam::DiscreteConditional* s);

@@ -43,7 +43,7 @@ inline HybridBayesNet createHybridBayesNet(size_t num_measurements = 1,
 // Create Gaussian mixture z_i = x0 + noise for each measurement.
 for (size_t i = 0; i < num_measurements; i++) {
 const auto mode_i = manyModes ? DiscreteKey{M(i), 2} : mode;
-bayesNet.emplace_shared<GaussianMixture>(
+bayesNet.emplace_shared<HybridGaussianConditional>(
 KeyVector{Z(i)}, KeyVector{X(0)}, DiscreteKeys{mode_i},
 std::vector{GaussianConditional::sharedMeanAndStddev(Z(i), I_1x1, X(0),
 Z_1x1, 0.5),

@@ -11,7 +11,7 @@
 /**
 * @file testGaussianMixture.cpp
-* @brief Unit tests for GaussianMixture class
+* @brief Unit tests for HybridGaussianConditional class
 * @author Varun Agrawal
 * @author Fan Jiang
 * @author Frank Dellaert
@@ -19,7 +19,7 @@
 */
 #include <gtsam/discrete/DiscreteValues.h>
-#include <gtsam/hybrid/GaussianMixture.h>
+#include <gtsam/hybrid/HybridGaussianConditional.h>
 #include <gtsam/hybrid/HybridGaussianFactor.h>
 #include <gtsam/hybrid/HybridValues.h>
 #include <gtsam/inference/Symbol.h>
@@ -46,19 +46,19 @@ static const HybridValues hv1{vv, assignment1};
 /* ************************************************************************* */
 namespace equal_constants {
-// Create a simple GaussianMixture
+// Create a simple HybridGaussianConditional
 const double commonSigma = 2.0;
 const std::vector<GaussianConditional::shared_ptr> conditionals{
 GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Vector1(0.0),
 commonSigma),
 GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Vector1(0.0),
 commonSigma)};
-const GaussianMixture mixture({Z(0)}, {X(0)}, {mode}, conditionals);
+const HybridGaussianConditional mixture({Z(0)}, {X(0)}, {mode}, conditionals);
 } // namespace equal_constants
 /* ************************************************************************* */
 /// Check that invariants hold
-TEST(GaussianMixture, Invariants) {
+TEST(HybridGaussianConditional, Invariants) {
 using namespace equal_constants;
 // Check that the mixture normalization constant is the max of all constants
@@ -67,13 +67,13 @@ TEST(GaussianMixture, Invariants) {
 EXPECT_DOUBLES_EQUAL(K, conditionals[0]->logNormalizationConstant(), 1e-8);
 EXPECT_DOUBLES_EQUAL(K, conditionals[1]->logNormalizationConstant(), 1e-8);
-EXPECT(GaussianMixture::CheckInvariants(mixture, hv0));
+EXPECT(HybridGaussianConditional::CheckInvariants(mixture, hv0));
-EXPECT(GaussianMixture::CheckInvariants(mixture, hv1));
+EXPECT(HybridGaussianConditional::CheckInvariants(mixture, hv1));
 }
 /* ************************************************************************* */
 /// Check LogProbability.
-TEST(GaussianMixture, LogProbability) {
+TEST(HybridGaussianConditional, LogProbability) {
 using namespace equal_constants;
 auto actual = mixture.logProbability(vv);
@@ -95,7 +95,7 @@ TEST(GaussianMixture, LogProbability) {
 /* ************************************************************************* */
 /// Check error.
-TEST(GaussianMixture, Error) {
+TEST(HybridGaussianConditional, Error) {
 using namespace equal_constants;
 auto actual = mixture.errorTree(vv);
@@ -118,7 +118,7 @@ TEST(GaussianMixture, Error) {
 /* ************************************************************************* */
 /// Check that the likelihood is proportional to the conditional density given
 /// the measurements.
-TEST(GaussianMixture, Likelihood) {
+TEST(HybridGaussianConditional, Likelihood) {
 using namespace equal_constants;
 // Compute likelihood
@@ -147,19 +147,19 @@ TEST(GaussianMixture, Likelihood) {
 /* ************************************************************************* */
 namespace mode_dependent_constants {
-// Create a GaussianMixture with mode-dependent noise models.
+// Create a HybridGaussianConditional with mode-dependent noise models.
 // 0 is low-noise, 1 is high-noise.
 const std::vector<GaussianConditional::shared_ptr> conditionals{
 GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Vector1(0.0),
 0.5),
 GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Vector1(0.0),
 3.0)};
-const GaussianMixture mixture({Z(0)}, {X(0)}, {mode}, conditionals);
+const HybridGaussianConditional mixture({Z(0)}, {X(0)}, {mode}, conditionals);
 } // namespace mode_dependent_constants
 /* ************************************************************************* */
 // Create a test for continuousParents.
-TEST(GaussianMixture, ContinuousParents) {
+TEST(HybridGaussianConditional, ContinuousParents) {
 using namespace mode_dependent_constants;
 const KeyVector continuousParentKeys = mixture.continuousParents();
 // Check that the continuous parent keys are correct:
@@ -170,7 +170,7 @@ TEST(GaussianMixture, ContinuousParents) {
 /* ************************************************************************* */
 /// Check that the likelihood is proportional to the conditional density given
 /// the measurements.
-TEST(GaussianMixture, Likelihood2) {
+TEST(HybridGaussianConditional, Likelihood2) {
 using namespace mode_dependent_constants;
 // Compute likelihood
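For context, the likelihood tests above exercise the conversion from conditional to factor; a hedged sketch, assuming the `mixture` fixture defined in this test file (the measurement value is illustrative):

    // Fixing the frontal measurement z0 turns P(z0 | x0, m0) into a factor on
    // x0 and m0. Note the returned type is still named GaussianMixtureFactor in
    // this commit; only the conditional class is renamed here.
    VectorValues measurement;
    measurement.insert(Z(0), Vector1(2.0));
    const auto likelihoodFactor = mixture.likelihood(measurement);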

@@ -20,7 +20,7 @@
 #include <gtsam/base/TestableAssertions.h>
 #include <gtsam/discrete/DiscreteValues.h>
-#include <gtsam/hybrid/GaussianMixture.h>
+#include <gtsam/hybrid/HybridGaussianConditional.h>
 #include <gtsam/hybrid/HybridGaussianFactor.h>
 #include <gtsam/hybrid/HybridBayesNet.h>
 #include <gtsam/hybrid/HybridGaussianFactorGraph.h>
@@ -144,7 +144,7 @@ Hybrid [x1 x2; 1]{
 }
 /* ************************************************************************* */
-TEST(HybridGaussianFactor, GaussianMixture) {
+TEST(HybridGaussianFactor, HybridGaussianConditional) {
 KeyVector keys;
 keys.push_back(X(0));
 keys.push_back(X(1));
@@ -154,8 +154,8 @@ TEST(HybridGaussianFactor, GaussianMixture) {
 dKeys.emplace_back(M(1), 2);
 auto gaussians = std::make_shared<GaussianConditional>();
-GaussianMixture::Conditionals conditionals(gaussians);
+HybridGaussianConditional::Conditionals conditionals(gaussians);
-GaussianMixture gm({}, keys, dKeys, conditionals);
+HybridGaussianConditional gm({}, keys, dKeys, conditionals);
 EXPECT_LONGS_EQUAL(2, gm.discreteKeys().size());
 }
@@ -229,7 +229,7 @@ static HybridBayesNet GetGaussianMixtureModel(double mu0, double mu1,
 c1 = make_shared<GaussianConditional>(z, Vector1(mu1), I_1x1, model1);
 HybridBayesNet hbn;
-hbn.emplace_shared<GaussianMixture>(KeyVector{z}, KeyVector{},
+hbn.emplace_shared<HybridGaussianConditional>(KeyVector{z}, KeyVector{},
 DiscreteKeys{m}, std::vector{c0, c1});
 auto mixing = make_shared<DiscreteConditional>(m, "0.5/0.5");
@@ -413,7 +413,7 @@ static HybridBayesNet CreateBayesNet(double mu0, double mu1, double sigma0,
 c1 = make_shared<GaussianConditional>(x1, Vector1(mu1), I_1x1, x0,
 -I_1x1, model1);
-auto motion = std::make_shared<GaussianMixture>(
+auto motion = std::make_shared<HybridGaussianConditional>(
 KeyVector{x1}, KeyVector{x0}, DiscreteKeys{m1}, std::vector{c0, c1});
 hbn.push_back(motion);

@@ -107,7 +107,7 @@ TEST(HybridBayesNet, evaluateHybrid) {
 // Create hybrid Bayes net.
 HybridBayesNet bayesNet;
 bayesNet.push_back(continuousConditional);
-bayesNet.emplace_shared<GaussianMixture>(
+bayesNet.emplace_shared<HybridGaussianConditional>(
 KeyVector{X(1)}, KeyVector{}, DiscreteKeys{Asia},
 std::vector{conditional0, conditional1});
 bayesNet.emplace_shared<DiscreteConditional>(Asia, "99/1");
@@ -168,7 +168,7 @@ TEST(HybridBayesNet, Error) {
 conditional1 = std::make_shared<GaussianConditional>(
 X(1), Vector1::Constant(2), I_1x1, model1);
-auto gm = std::make_shared<GaussianMixture>(
+auto gm = std::make_shared<HybridGaussianConditional>(
 KeyVector{X(1)}, KeyVector{}, DiscreteKeys{Asia},
 std::vector{conditional0, conditional1});
 // Create hybrid Bayes net.

@@ -43,9 +43,9 @@ TEST(HybridConditional, Invariants) {
 auto hc0 = bn.at(0);
 CHECK(hc0->isHybrid());
-// Check invariants as a GaussianMixture.
+// Check invariants as a HybridGaussianConditional.
 const auto mixture = hc0->asMixture();
-EXPECT(GaussianMixture::CheckInvariants(*mixture, values));
+EXPECT(HybridGaussianConditional::CheckInvariants(*mixture, values));
 // Check invariants as a HybridConditional.
 EXPECT(HybridConditional::CheckInvariants(*hc0, values));

@@ -616,7 +616,7 @@ TEST(HybridEstimation, ModeSelection) {
 GaussianConditional::sharedMeanAndStddev(Z(0), -I_1x1, X(0), Z_1x1, 0.1));
 bn.push_back(
 GaussianConditional::sharedMeanAndStddev(Z(0), -I_1x1, X(1), Z_1x1, 0.1));
-bn.emplace_shared<GaussianMixture>(
+bn.emplace_shared<HybridGaussianConditional>(
 KeyVector{Z(0)}, KeyVector{X(0), X(1)}, DiscreteKeys{mode},
 std::vector{GaussianConditional::sharedMeanAndStddev(
 Z(0), I_1x1, X(0), -I_1x1, X(1), Z_1x1, noise_loose),
@@ -647,7 +647,7 @@ TEST(HybridEstimation, ModeSelection2) {
 GaussianConditional::sharedMeanAndStddev(Z(0), -I_3x3, X(0), Z_3x1, 0.1));
 bn.push_back(
 GaussianConditional::sharedMeanAndStddev(Z(0), -I_3x3, X(1), Z_3x1, 0.1));
-bn.emplace_shared<GaussianMixture>(
+bn.emplace_shared<HybridGaussianConditional>(
 KeyVector{Z(0)}, KeyVector{X(0), X(1)}, DiscreteKeys{mode},
 std::vector{GaussianConditional::sharedMeanAndStddev(
 Z(0), I_3x3, X(0), -I_3x3, X(1), Z_3x1, noise_loose),

@ -21,7 +21,7 @@
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/discrete/DiscreteValues.h> #include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/GaussianMixture.h> #include <gtsam/hybrid/HybridGaussianConditional.h>
#include <gtsam/hybrid/HybridGaussianFactor.h> #include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h> #include <gtsam/hybrid/HybridBayesTree.h>
@ -71,8 +71,8 @@ TEST(HybridGaussianFactorGraph, Creation) {
// Define a gaussian mixture conditional P(x0|x1, c0) and add it to the factor // Define a gaussian mixture conditional P(x0|x1, c0) and add it to the factor
// graph // graph
GaussianMixture gm({X(0)}, {X(1)}, DiscreteKeys(DiscreteKey{M(0), 2}), HybridGaussianConditional gm({X(0)}, {X(1)}, DiscreteKeys(DiscreteKey{M(0), 2}),
GaussianMixture::Conditionals( HybridGaussianConditional::Conditionals(
M(0), M(0),
std::make_shared<GaussianConditional>( std::make_shared<GaussianConditional>(
X(0), Z_3x1, I_3x3, X(1), I_3x3), X(0), Z_3x1, I_3x3, X(1), I_3x3),
@ -681,7 +681,7 @@ TEST(HybridGaussianFactorGraph, ErrorTreeWithConditional) {
x0, -I_1x1, model0), x0, -I_1x1, model0),
c1 = make_shared<GaussianConditional>(f01, Vector1(mu), I_1x1, x1, I_1x1, c1 = make_shared<GaussianConditional>(f01, Vector1(mu), I_1x1, x1, I_1x1,
x0, -I_1x1, model1); x0, -I_1x1, model1);
hbn.emplace_shared<GaussianMixture>(KeyVector{f01}, KeyVector{x0, x1}, hbn.emplace_shared<HybridGaussianConditional>(KeyVector{f01}, KeyVector{x0, x1},
DiscreteKeys{m1}, std::vector{c0, c1}); DiscreteKeys{m1}, std::vector{c0, c1});
// Discrete uniform prior. // Discrete uniform prior.
@ -805,7 +805,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
X(0), Vector1(14.1421), I_1x1 * 2.82843), X(0), Vector1(14.1421), I_1x1 * 2.82843),
conditional1 = std::make_shared<GaussianConditional>( conditional1 = std::make_shared<GaussianConditional>(
X(0), Vector1(10.1379), I_1x1 * 2.02759); X(0), Vector1(10.1379), I_1x1 * 2.02759);
expectedBayesNet.emplace_shared<GaussianMixture>( expectedBayesNet.emplace_shared<HybridGaussianConditional>(
KeyVector{X(0)}, KeyVector{}, DiscreteKeys{mode}, KeyVector{X(0)}, KeyVector{}, DiscreteKeys{mode},
std::vector{conditional0, conditional1}); std::vector{conditional0, conditional1});
@ -830,7 +830,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1Swapped) {
HybridBayesNet bn; HybridBayesNet bn;
// Create Gaussian mixture z_0 = x0 + noise for each measurement. // Create Gaussian mixture z_0 = x0 + noise for each measurement.
auto gm = std::make_shared<GaussianMixture>( auto gm = std::make_shared<HybridGaussianConditional>(
KeyVector{Z(0)}, KeyVector{X(0)}, DiscreteKeys{mode}, KeyVector{Z(0)}, KeyVector{X(0)}, DiscreteKeys{mode},
std::vector{ std::vector{
GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1, 3), GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1, 3),
@ -862,7 +862,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1Swapped) {
X(0), Vector1(10.1379), I_1x1 * 2.02759), X(0), Vector1(10.1379), I_1x1 * 2.02759),
conditional1 = std::make_shared<GaussianConditional>( conditional1 = std::make_shared<GaussianConditional>(
X(0), Vector1(14.1421), I_1x1 * 2.82843); X(0), Vector1(14.1421), I_1x1 * 2.82843);
expectedBayesNet.emplace_shared<GaussianMixture>( expectedBayesNet.emplace_shared<HybridGaussianConditional>(
KeyVector{X(0)}, KeyVector{}, DiscreteKeys{mode}, KeyVector{X(0)}, KeyVector{}, DiscreteKeys{mode},
std::vector{conditional0, conditional1}); std::vector{conditional0, conditional1});
@ -899,7 +899,7 @@ TEST(HybridGaussianFactorGraph, EliminateTiny2) {
X(0), Vector1(17.3205), I_1x1 * 3.4641), X(0), Vector1(17.3205), I_1x1 * 3.4641),
conditional1 = std::make_shared<GaussianConditional>( conditional1 = std::make_shared<GaussianConditional>(
X(0), Vector1(10.274), I_1x1 * 2.0548); X(0), Vector1(10.274), I_1x1 * 2.0548);
expectedBayesNet.emplace_shared<GaussianMixture>( expectedBayesNet.emplace_shared<HybridGaussianConditional>(
KeyVector{X(0)}, KeyVector{}, DiscreteKeys{mode}, KeyVector{X(0)}, KeyVector{}, DiscreteKeys{mode},
std::vector{conditional0, conditional1}); std::vector{conditional0, conditional1});
@ -946,7 +946,7 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
for (size_t t : {0, 1, 2}) { for (size_t t : {0, 1, 2}) {
// Create Gaussian mixture on Z(t) conditioned on X(t) and mode N(t): // Create Gaussian mixture on Z(t) conditioned on X(t) and mode N(t):
const auto noise_mode_t = DiscreteKey{N(t), 2}; const auto noise_mode_t = DiscreteKey{N(t), 2};
bn.emplace_shared<GaussianMixture>( bn.emplace_shared<HybridGaussianConditional>(
KeyVector{Z(t)}, KeyVector{X(t)}, DiscreteKeys{noise_mode_t}, KeyVector{Z(t)}, KeyVector{X(t)}, DiscreteKeys{noise_mode_t},
std::vector{GaussianConditional::sharedMeanAndStddev(Z(t), I_1x1, X(t), std::vector{GaussianConditional::sharedMeanAndStddev(Z(t), I_1x1, X(t),
Z_1x1, 0.5), Z_1x1, 0.5),
@ -961,7 +961,7 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
for (size_t t : {2, 1}) { for (size_t t : {2, 1}) {
// Create Gaussian mixture on X(t) conditioned on X(t-1) and mode M(t-1): // Create Gaussian mixture on X(t) conditioned on X(t-1) and mode M(t-1):
const auto motion_model_t = DiscreteKey{M(t), 2}; const auto motion_model_t = DiscreteKey{M(t), 2};
auto gm = std::make_shared<GaussianMixture>( auto gm = std::make_shared<HybridGaussianConditional>(
KeyVector{X(t)}, KeyVector{X(t - 1)}, DiscreteKeys{motion_model_t}, KeyVector{X(t)}, KeyVector{X(t - 1)}, DiscreteKeys{motion_model_t},
std::vector{GaussianConditional::sharedMeanAndStddev( std::vector{GaussianConditional::sharedMeanAndStddev(
X(t), I_1x1, X(t - 1), Z_1x1, 0.2), X(t), I_1x1, X(t - 1), Z_1x1, 0.2),
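Editor's note: both mixtures in this loop follow the same two-branch pattern; the second-branch sigmas are cut off above. A sketch of one time slice (not part of the commit), with hypothetical values 2.0 and 1.0 standing in for the truncated sigmas:

// Sketch only: one time slice of the switching network, t = 1.
const size_t t = 1;
const DiscreteKey noise_mode_t(N(t), 2), motion_model_t(M(t), 2);
// Measurement mixture P(z_t | x_t, n_t): tight vs. loose noise.
bn.emplace_shared<HybridGaussianConditional>(
    KeyVector{Z(t)}, KeyVector{X(t)}, DiscreteKeys{noise_mode_t},
    std::vector{
        GaussianConditional::sharedMeanAndStddev(Z(t), I_1x1, X(t), Z_1x1, 0.5),
        GaussianConditional::sharedMeanAndStddev(Z(t), I_1x1, X(t), Z_1x1, 2.0)});
// Motion mixture P(x_t | x_{t-1}, m_t): nominal vs. alternate model.
bn.emplace_shared<HybridGaussianConditional>(
    KeyVector{X(t)}, KeyVector{X(t - 1)}, DiscreteKeys{motion_model_t},
    std::vector{
        GaussianConditional::sharedMeanAndStddev(X(t), I_1x1, X(t - 1), Z_1x1, 0.2),
        GaussianConditional::sharedMeanAndStddev(X(t), I_1x1, X(t - 1), Z_1x1, 1.0)});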

@ -134,22 +134,22 @@ TEST(HybridGaussianElimination, IncrementalInference) {
// The densities on X(0) should be the same // The densities on X(0) should be the same
auto x0_conditional = auto x0_conditional =
dynamic_pointer_cast<GaussianMixture>(isam[X(0)]->conditional()->inner()); dynamic_pointer_cast<HybridGaussianConditional>(isam[X(0)]->conditional()->inner());
auto expected_x0_conditional = dynamic_pointer_cast<GaussianMixture>( auto expected_x0_conditional = dynamic_pointer_cast<HybridGaussianConditional>(
(*expectedHybridBayesTree)[X(0)]->conditional()->inner()); (*expectedHybridBayesTree)[X(0)]->conditional()->inner());
EXPECT(assert_equal(*x0_conditional, *expected_x0_conditional)); EXPECT(assert_equal(*x0_conditional, *expected_x0_conditional));
// The densities on X(1) should be the same // The densities on X(1) should be the same
auto x1_conditional = auto x1_conditional =
dynamic_pointer_cast<GaussianMixture>(isam[X(1)]->conditional()->inner()); dynamic_pointer_cast<HybridGaussianConditional>(isam[X(1)]->conditional()->inner());
auto expected_x1_conditional = dynamic_pointer_cast<GaussianMixture>( auto expected_x1_conditional = dynamic_pointer_cast<HybridGaussianConditional>(
(*expectedHybridBayesTree)[X(1)]->conditional()->inner()); (*expectedHybridBayesTree)[X(1)]->conditional()->inner());
EXPECT(assert_equal(*x1_conditional, *expected_x1_conditional)); EXPECT(assert_equal(*x1_conditional, *expected_x1_conditional));
// The densities on X(2) should be the same // The densities on X(2) should be the same
auto x2_conditional = auto x2_conditional =
dynamic_pointer_cast<GaussianMixture>(isam[X(2)]->conditional()->inner()); dynamic_pointer_cast<HybridGaussianConditional>(isam[X(2)]->conditional()->inner());
auto expected_x2_conditional = dynamic_pointer_cast<GaussianMixture>( auto expected_x2_conditional = dynamic_pointer_cast<HybridGaussianConditional>(
(*expectedHybridBayesTree)[X(2)]->conditional()->inner()); (*expectedHybridBayesTree)[X(2)]->conditional()->inner());
EXPECT(assert_equal(*x2_conditional, *expected_x2_conditional)); EXPECT(assert_equal(*x2_conditional, *expected_x2_conditional));
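Editor's note: the retrieval pattern repeated in these checks is worth calling out: clique conditionals in a hybrid Bayes tree are type-erased HybridConditionals, and the renamed class is recovered from inner(). A sketch (not part of the commit), assuming isam is the HybridGaussianISAM used in the test:

// Sketch only: recover the hybrid density on X(0) from the Bayes tree.
auto hybrid = isam[X(0)]->conditional();  // shared_ptr<HybridConditional>
auto density =
    std::dynamic_pointer_cast<HybridGaussianConditional>(hybrid->inner());
if (density) {
  // One GaussianConditional per assignment of the discrete parents.
  density->print("P(x0 | ..., modes):");
}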
@ -279,9 +279,9 @@ TEST(HybridGaussianElimination, Approx_inference) {
// Check that the hybrid nodes of the bayes net match those of the pre-pruning // Check that the hybrid nodes of the bayes net match those of the pre-pruning
// bayes net, at the same positions. // bayes net, at the same positions.
auto &unprunedLastDensity = *dynamic_pointer_cast<GaussianMixture>( auto &unprunedLastDensity = *dynamic_pointer_cast<HybridGaussianConditional>(
unprunedHybridBayesTree->clique(X(3))->conditional()->inner()); unprunedHybridBayesTree->clique(X(3))->conditional()->inner());
auto &lastDensity = *dynamic_pointer_cast<GaussianMixture>( auto &lastDensity = *dynamic_pointer_cast<HybridGaussianConditional>(
incrementalHybrid[X(3)]->conditional()->inner()); incrementalHybrid[X(3)]->conditional()->inner());
std::vector<std::pair<DiscreteValues, double>> assignments = std::vector<std::pair<DiscreteValues, double>> assignments =

@ -350,7 +350,7 @@ TEST(HybridGaussianElimination, EliminateHybrid_2_Variable) {
EliminateHybrid(factors, ordering); EliminateHybrid(factors, ordering);
auto gaussianConditionalMixture = auto gaussianConditionalMixture =
dynamic_pointer_cast<GaussianMixture>(hybridConditionalMixture->inner()); dynamic_pointer_cast<HybridGaussianConditional>(hybridConditionalMixture->inner());
CHECK(gaussianConditionalMixture); CHECK(gaussianConditionalMixture);
// Frontals = [x0, x1] // Frontals = [x0, x1]

@ -151,23 +151,23 @@ TEST(HybridNonlinearISAM, IncrementalInference) {
.BaseEliminateable::eliminatePartialMultifrontal(ordering); .BaseEliminateable::eliminatePartialMultifrontal(ordering);
// The densities on X(1) should be the same // The densities on X(1) should be the same
auto x0_conditional = dynamic_pointer_cast<GaussianMixture>( auto x0_conditional = dynamic_pointer_cast<HybridGaussianConditional>(
bayesTree[X(0)]->conditional()->inner()); bayesTree[X(0)]->conditional()->inner());
auto expected_x0_conditional = dynamic_pointer_cast<GaussianMixture>( auto expected_x0_conditional = dynamic_pointer_cast<HybridGaussianConditional>(
(*expectedHybridBayesTree)[X(0)]->conditional()->inner()); (*expectedHybridBayesTree)[X(0)]->conditional()->inner());
EXPECT(assert_equal(*x0_conditional, *expected_x0_conditional)); EXPECT(assert_equal(*x0_conditional, *expected_x0_conditional));
// The densities on X(1) should be the same // The densities on X(1) should be the same
auto x1_conditional = dynamic_pointer_cast<GaussianMixture>( auto x1_conditional = dynamic_pointer_cast<HybridGaussianConditional>(
bayesTree[X(1)]->conditional()->inner()); bayesTree[X(1)]->conditional()->inner());
auto expected_x1_conditional = dynamic_pointer_cast<GaussianMixture>( auto expected_x1_conditional = dynamic_pointer_cast<HybridGaussianConditional>(
(*expectedHybridBayesTree)[X(1)]->conditional()->inner()); (*expectedHybridBayesTree)[X(1)]->conditional()->inner());
EXPECT(assert_equal(*x1_conditional, *expected_x1_conditional)); EXPECT(assert_equal(*x1_conditional, *expected_x1_conditional));
// The densities on X(2) should be the same // The densities on X(2) should be the same
auto x2_conditional = dynamic_pointer_cast<GaussianMixture>( auto x2_conditional = dynamic_pointer_cast<HybridGaussianConditional>(
bayesTree[X(2)]->conditional()->inner()); bayesTree[X(2)]->conditional()->inner());
auto expected_x2_conditional = dynamic_pointer_cast<GaussianMixture>( auto expected_x2_conditional = dynamic_pointer_cast<HybridGaussianConditional>(
(*expectedHybridBayesTree)[X(2)]->conditional()->inner()); (*expectedHybridBayesTree)[X(2)]->conditional()->inner());
EXPECT(assert_equal(*x2_conditional, *expected_x2_conditional)); EXPECT(assert_equal(*x2_conditional, *expected_x2_conditional));
@ -300,9 +300,9 @@ TEST(HybridNonlinearISAM, Approx_inference) {
// Check that the hybrid nodes of the bayes net match those of the pre-pruning // Check that the hybrid nodes of the bayes net match those of the pre-pruning
// bayes net, at the same positions. // bayes net, at the same positions.
auto &unprunedLastDensity = *dynamic_pointer_cast<GaussianMixture>( auto &unprunedLastDensity = *dynamic_pointer_cast<HybridGaussianConditional>(
unprunedHybridBayesTree->clique(X(3))->conditional()->inner()); unprunedHybridBayesTree->clique(X(3))->conditional()->inner());
auto &lastDensity = *dynamic_pointer_cast<GaussianMixture>( auto &lastDensity = *dynamic_pointer_cast<HybridGaussianConditional>(
bayesTree[X(3)]->conditional()->inner()); bayesTree[X(3)]->conditional()->inner());
std::vector<std::pair<DiscreteValues, double>> assignments = std::vector<std::pair<DiscreteValues, double>> assignments =

@ -18,7 +18,7 @@
#include <gtsam/base/serializationTestHelpers.h> #include <gtsam/base/serializationTestHelpers.h>
#include <gtsam/discrete/DiscreteConditional.h> #include <gtsam/discrete/DiscreteConditional.h>
#include <gtsam/hybrid/GaussianMixture.h> #include <gtsam/hybrid/HybridGaussianConditional.h>
#include <gtsam/hybrid/HybridGaussianFactor.h> #include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridBayesNet.h> #include <gtsam/hybrid/HybridBayesNet.h>
#include <gtsam/hybrid/HybridBayesTree.h> #include <gtsam/hybrid/HybridBayesTree.h>
@ -59,12 +59,12 @@ BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::Factors::Leaf,
BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::Factors::Choice, BOOST_CLASS_EXPORT_GUID(HybridGaussianFactor::Factors::Choice,
"gtsam_GaussianMixtureFactor_Factors_Choice"); "gtsam_GaussianMixtureFactor_Factors_Choice");
BOOST_CLASS_EXPORT_GUID(GaussianMixture, "gtsam_GaussianMixture"); BOOST_CLASS_EXPORT_GUID(HybridGaussianConditional, "gtsam_GaussianMixture");
BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals, BOOST_CLASS_EXPORT_GUID(HybridGaussianConditional::Conditionals,
"gtsam_GaussianMixture_Conditionals"); "gtsam_GaussianMixture_Conditionals");
BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals::Leaf, BOOST_CLASS_EXPORT_GUID(HybridGaussianConditional::Conditionals::Leaf,
"gtsam_GaussianMixture_Conditionals_Leaf"); "gtsam_GaussianMixture_Conditionals_Leaf");
BOOST_CLASS_EXPORT_GUID(GaussianMixture::Conditionals::Choice, BOOST_CLASS_EXPORT_GUID(HybridGaussianConditional::Conditionals::Choice,
"gtsam_GaussianMixture_Conditionals_Choice"); "gtsam_GaussianMixture_Conditionals_Choice");
// Needed since GaussianConditional::FromMeanAndStddev uses it // Needed since GaussianConditional::FromMeanAndStddev uses it
BOOST_CLASS_EXPORT_GUID(noiseModel::Isotropic, "gtsam_noiseModel_Isotropic"); BOOST_CLASS_EXPORT_GUID(noiseModel::Isotropic, "gtsam_noiseModel_Isotropic");
@ -106,20 +106,20 @@ TEST(HybridSerialization, HybridConditional) {
} }
/* ****************************************************************************/ /* ****************************************************************************/
// Test GaussianMixture serialization. // Test HybridGaussianConditional serialization.
TEST(HybridSerialization, GaussianMixture) { TEST(HybridSerialization, HybridGaussianConditional) {
const DiscreteKey mode(M(0), 2); const DiscreteKey mode(M(0), 2);
Matrix1 I = Matrix1::Identity(); Matrix1 I = Matrix1::Identity();
const auto conditional0 = std::make_shared<GaussianConditional>( const auto conditional0 = std::make_shared<GaussianConditional>(
GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 0.5)); GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 0.5));
const auto conditional1 = std::make_shared<GaussianConditional>( const auto conditional1 = std::make_shared<GaussianConditional>(
GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 3)); GaussianConditional::FromMeanAndStddev(Z(0), I, X(0), Vector1(0), 3));
const GaussianMixture gm({Z(0)}, {X(0)}, {mode}, const HybridGaussianConditional gm({Z(0)}, {X(0)}, {mode},
{conditional0, conditional1}); {conditional0, conditional1});
EXPECT(equalsObj<GaussianMixture>(gm)); EXPECT(equalsObj<HybridGaussianConditional>(gm));
EXPECT(equalsXML<GaussianMixture>(gm)); EXPECT(equalsXML<HybridGaussianConditional>(gm));
EXPECT(equalsBinary<GaussianMixture>(gm)); EXPECT(equalsBinary<HybridGaussianConditional>(gm));
} }
/* ****************************************************************************/ /* ****************************************************************************/
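Editor's note: the export GUID strings on the renamed types deliberately keep the old "gtsam_GaussianMixture" names, presumably so boost archives written before the rename still identify and load the class; that reading is an inference from the hunk above, not something the commit states. A consolidated sketch of the round-trip check (restating the test above, assuming the serializationTestHelpers header it includes):

// Sketch only: round-trip a HybridGaussianConditional through the object,
// XML, and binary archive helpers and compare with equals().
const DiscreteKey mode(M(0), 2);
const auto c0 = std::make_shared<GaussianConditional>(
    GaussianConditional::FromMeanAndStddev(Z(0), I_1x1, X(0), Vector1(0), 0.5));
const auto c1 = std::make_shared<GaussianConditional>(
    GaussianConditional::FromMeanAndStddev(Z(0), I_1x1, X(0), Vector1(0), 3.0));
const HybridGaussianConditional gm({Z(0)}, {X(0)}, {mode}, {c0, c1});
EXPECT(equalsObj<HybridGaussianConditional>(gm));
EXPECT(equalsXML<HybridGaussianConditional>(gm));
EXPECT(equalsBinary<HybridGaussianConditional>(gm));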

@ -46,7 +46,7 @@ namespace gtsam {
* Gaussian density over a set of continuous variables. * Gaussian density over a set of continuous variables.
* - \b Discrete conditionals, implemented in \class DiscreteConditional, which * - \b Discrete conditionals, implemented in \class DiscreteConditional, which
* represent a discrete conditional distribution over discrete variables. * represent a discrete conditional distribution over discrete variables.
* - \b Hybrid conditional densities, such as \class GaussianMixture, which is * - \b Hybrid conditional densities, such as \class HybridGaussianConditional, which is
* a density over continuous variables given discrete/continuous parents. * a density over continuous variables given discrete/continuous parents.
* - \b Symbolic factors, used to represent a graph structure, implemented in * - \b Symbolic factors, used to represent a graph structure, implemented in
* \class SymbolicConditional. Only used for symbolic elimination etc. * \class SymbolicConditional. Only used for symbolic elimination etc.
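Editor's note: since the doc comment above enumerates the conditional types a HybridConditional can wrap, a short sketch may help. It is not part of the commit and assumes wrapper constructors that take a shared_ptr to each inner type (an assumption about the API, not shown in this diff):

// Sketch only: wrapping each conditional type in a HybridConditional.
const DiscreteKey mode(M(0), 2);

// Continuous: a plain GaussianConditional P(z0 | x0).
auto gc = GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1, 1.0);
HybridConditional wrapsGaussian(gc);

// Discrete: P(m0) with a 40/60 split.
auto dc = std::make_shared<DiscreteConditional>(mode, "4/6");
HybridConditional wrapsDiscrete(dc);

// Hybrid: P(z0 | x0, m0), one Gaussian per mode value.
auto hgc = std::make_shared<HybridGaussianConditional>(
    KeyVector{Z(0)}, KeyVector{X(0)}, DiscreteKeys{mode},
    std::vector{
        GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1, 0.5),
        GaussianConditional::sharedMeanAndStddev(Z(0), I_1x1, X(0), Z_1x1, 3.0)});
HybridConditional wrapsHybrid(hgc);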

@ -18,8 +18,9 @@ from gtsam.symbol_shorthand import A, X
from gtsam.utils.test_case import GtsamTestCase from gtsam.utils.test_case import GtsamTestCase
from gtsam import (DiscreteConditional, DiscreteKeys, DiscreteValues, from gtsam import (DiscreteConditional, DiscreteKeys, DiscreteValues,
GaussianConditional, GaussianMixture, HybridBayesNet, GaussianConditional, HybridBayesNet,
HybridValues, VectorValues, noiseModel) HybridGaussianConditional, HybridValues, VectorValues,
noiseModel)
class TestHybridBayesNet(GtsamTestCase): class TestHybridBayesNet(GtsamTestCase):
@ -49,7 +50,7 @@ class TestHybridBayesNet(GtsamTestCase):
bayesNet = HybridBayesNet() bayesNet = HybridBayesNet()
bayesNet.push_back(conditional) bayesNet.push_back(conditional)
bayesNet.push_back( bayesNet.push_back(
GaussianMixture([X(1)], [], discrete_keys, HybridGaussianConditional([X(1)], [], discrete_keys,
[conditional0, conditional1])) [conditional0, conditional1]))
bayesNet.push_back(DiscreteConditional(Asia, "99/1")) bayesNet.push_back(DiscreteConditional(Asia, "99/1"))

@ -18,9 +18,9 @@ from gtsam.utils.test_case import GtsamTestCase
import gtsam import gtsam
from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional, from gtsam import (DiscreteConditional, DiscreteKeys, GaussianConditional,
GaussianMixture, HybridBayesNet, HybridGaussianFactor, HybridBayesNet, HybridGaussianConditional,
HybridGaussianFactorGraph, HybridValues, JacobianFactor, HybridGaussianFactor, HybridGaussianFactorGraph,
Ordering, noiseModel) HybridValues, JacobianFactor, Ordering, noiseModel)
DEBUG_MARGINALS = False DEBUG_MARGINALS = False
@ -48,7 +48,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
self.assertEqual(hbn.size(), 2) self.assertEqual(hbn.size(), 2)
mixture = hbn.at(0).inner() mixture = hbn.at(0).inner()
self.assertIsInstance(mixture, GaussianMixture) self.assertIsInstance(mixture, HybridGaussianConditional)
self.assertEqual(len(mixture.keys()), 2) self.assertEqual(len(mixture.keys()), 2)
discrete_conditional = hbn.at(hbn.size() - 1).inner() discrete_conditional = hbn.at(hbn.size() - 1).inner()
@ -106,7 +106,7 @@ class TestHybridGaussianFactorGraph(GtsamTestCase):
I_1x1, I_1x1,
X(0), [0], X(0), [0],
sigma=3) sigma=3)
bayesNet.push_back(GaussianMixture([Z(i)], [X(0)], keys, bayesNet.push_back(HybridGaussianConditional([Z(i)], [X(0)], keys,
[conditional0, conditional1])) [conditional0, conditional1]))
# Create prior on X(0). # Create prior on X(0).