diff --git a/gtsam/hybrid/GaussianMixture.cpp b/gtsam/hybrid/GaussianMixture.cpp index 531a9f2e5..36a34226b 100644 --- a/gtsam/hybrid/GaussianMixture.cpp +++ b/gtsam/hybrid/GaussianMixture.cpp @@ -24,6 +24,7 @@ #include #include #include +#include #include namespace gtsam { @@ -92,6 +93,35 @@ GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const { return {conditionals_, wrap}; } +/* +*******************************************************************************/ +GaussianBayesNetTree GaussianMixture::add( + const GaussianBayesNetTree &sum) const { + using Y = GaussianBayesNet; + auto add = [](const Y &graph1, const Y &graph2) { + auto result = graph1; + if (graph2.size() == 0) { + return GaussianBayesNet(); + } + result.push_back(graph2); + return result; + }; + const auto tree = asGaussianBayesNetTree(); + return sum.empty() ? tree : sum.apply(tree, add); +} + +/* *******************************************************************************/ +GaussianBayesNetTree GaussianMixture::asGaussianBayesNetTree() const { + auto wrap = [](const GaussianConditional::shared_ptr &gc) { + if (gc) { + return GaussianBayesNet{gc}; + } else { + return GaussianBayesNet(); + } + }; + return {conditionals_, wrap}; +} + /* *******************************************************************************/ size_t GaussianMixture::nrComponents() const { size_t total = 0; @@ -318,8 +348,15 @@ AlgebraicDecisionTree GaussianMixture::logProbability( AlgebraicDecisionTree GaussianMixture::errorTree( const VectorValues &continuousValues) const { auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) { - return conditional->error(continuousValues) + // - logConstant_ - conditional->logNormalizationConstant(); + // Check if valid pointer + if (conditional) { + return conditional->error(continuousValues) + // + logConstant_ - conditional->logNormalizationConstant(); + } else { + // If not valid, pointer, it means this conditional was pruned, + // so we return maximum error. + return std::numeric_limits::max(); + } }; DecisionTree error_tree(conditionals_, errorFunc); return error_tree; @@ -327,10 +364,32 @@ AlgebraicDecisionTree GaussianMixture::errorTree( /* *******************************************************************************/ double GaussianMixture::error(const HybridValues &values) const { + // Check if discrete keys in discrete assignment are + // present in the GaussianMixture + KeyVector dKeys = this->discreteKeys_.indices(); + bool valid_assignment = false; + for (auto &&kv : values.discrete()) { + if (std::find(dKeys.begin(), dKeys.end(), kv.first) != dKeys.end()) { + valid_assignment = true; + break; + } + } + + // The discrete assignment is not valid so we return 0.0 erorr. + if (!valid_assignment) { + return 0.0; + } + // Directly index to get the conditional, no need to build the whole tree. auto conditional = conditionals_(values.discrete()); - return conditional->error(values.continuous()) + // - logConstant_ - conditional->logNormalizationConstant(); + if (conditional) { + return conditional->error(values.continuous()) + // + logConstant_ - conditional->logNormalizationConstant(); + } else { + // If not valid, pointer, it means this conditional was pruned, + // so we return maximum error. + return std::numeric_limits::max(); + } } /* *******************************************************************************/ diff --git a/gtsam/hybrid/GaussianMixture.h b/gtsam/hybrid/GaussianMixture.h index c1ef504f8..bfa342dcf 100644 --- a/gtsam/hybrid/GaussianMixture.h +++ b/gtsam/hybrid/GaussianMixture.h @@ -72,6 +72,12 @@ class GTSAM_EXPORT GaussianMixture */ GaussianFactorGraphTree asGaussianFactorGraphTree() const; + /** + * @brief Convert a DecisionTree of conditionals into + * a DecisionTree of Gaussian Bayes nets. + */ + GaussianBayesNetTree asGaussianBayesNetTree() const; + /** * @brief Helper function to get the pruner functor. * @@ -250,6 +256,15 @@ class GTSAM_EXPORT GaussianMixture * @return GaussianFactorGraphTree */ GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const; + + /** + * @brief Merge the Gaussian Bayes Nets in `this` and `sum` while + * maintaining the decision tree structure. + * + * @param sum Decision Tree of Gaussian Bayes Nets + * @return GaussianBayesNetTree + */ + GaussianBayesNetTree add(const GaussianBayesNetTree &sum) const; /// @} private: diff --git a/gtsam/hybrid/HybridFactor.h b/gtsam/hybrid/HybridFactor.h index afd1c8032..418489d66 100644 --- a/gtsam/hybrid/HybridFactor.h +++ b/gtsam/hybrid/HybridFactor.h @@ -13,6 +13,7 @@ * @file HybridFactor.h * @date Mar 11, 2022 * @author Fan Jiang + * @author Varun Agrawal */ #pragma once @@ -33,6 +34,8 @@ class HybridValues; /// Alias for DecisionTree of GaussianFactorGraphs using GaussianFactorGraphTree = DecisionTree; +/// Alias for DecisionTree of GaussianBayesNets +using GaussianBayesNetTree = DecisionTree; KeyVector CollectKeys(const KeyVector &continuousKeys, const DiscreteKeys &discreteKeys);