diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index a1ba167f5..61b40e566 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #include namespace gtsam { @@ -92,6 +93,34 @@ GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const { return {conditionals_, wrap}; } +/* *******************************************************************************/ +GaussianBayesNetTree GaussianMixture::add( + const GaussianBayesNetTree &sum) const { + using Y = GaussianBayesNet; + auto add = [](const Y &graph1, const Y &graph2) { + auto result = graph1; + if (graph2.size() == 0) { + return GaussianBayesNet(); + } + result.push_back(graph2); + return result; + }; + const auto tree = asGaussianBayesNetTree(); + return sum.empty() ? tree : sum.apply(tree, add); +} + +/* *******************************************************************************/ +GaussianBayesNetTree GaussianMixture::asGaussianBayesNetTree() const { + auto wrap = [](const GaussianConditional::shared_ptr &gc) { + if (gc) { + return GaussianBayesNet{gc}; + } else { + return GaussianBayesNet(); + } + }; + return {conditionals_, wrap}; +} + /* *******************************************************************************/ size_t GaussianMixture::nrComponents() const { size_t total = 0; @@ -332,10 +361,32 @@ AlgebraicDecisionTree GaussianMixture::error( /* *******************************************************************************/ double GaussianMixture::error(const HybridValues &values) const { + // Check if discrete keys in discrete assignment are + // present in the GaussianMixture + KeyVector dKeys = this->discreteKeys_.indices(); + bool valid_assignment = false; + for (auto &&kv : values.discrete()) { + if (std::find(dKeys.begin(), dKeys.end(), kv.first) != dKeys.end()) { + valid_assignment = true; + break; + } + } + + // The discrete assignment is not valid so we return 0.0 erorr. + if (!valid_assignment) { + return 0.0; + } + // Directly index to get the conditional, no need to build the whole tree. auto conditional = conditionals_(values.discrete()); - return conditional->error(values.continuous()) + // - logConstant_ - conditional->logNormalizationConstant(); + if (conditional) { + return conditional->error(values.continuous()) + // + logConstant_ - conditional->logNormalizationConstant(); + } else { + // If not valid, pointer, it means this conditional was pruned, + // so we return maximum error. + return std::numeric_limits::max(); + } } /* *******************************************************************************/ diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index 0b68fcfd0..a52cd6c82 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -71,6 +71,12 @@ class GTSAM_EXPORT GaussianMixture */ GaussianFactorGraphTree asGaussianFactorGraphTree() const; + /** + * @brief Convert a DecisionTree of conditionals into + * a DT of Gaussian Bayes nets. + */ + GaussianBayesNetTree asGaussianBayesNetTree() const; + /** * @brief Helper function to get the pruner functor. * @@ -248,6 +254,15 @@ class GTSAM_EXPORT GaussianMixture * @return GaussianFactorGraphTree */ GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const; + + /** + * @brief Merge the Gaussian Bayes Nets in `this` and `sum` while + * maintaining the decision tree structure. + * + * @param sum Decision Tree of Gaussian Bayes Nets + * @return GaussianBayesNetTree + */ + GaussianBayesNetTree add(const GaussianBayesNetTree &sum) const; /// @} private: