helper methods in GaussianMixture for model selection

release/4.3a0
Varun Agrawal 2023-12-27 15:45:35 -05:00
parent b20d33d79e
commit 3a89653e91
2 changed files with 68 additions and 2 deletions

View File

@ -24,6 +24,7 @@
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianBayesNet.h>
#include <gtsam/linear/GaussianFactorGraph.h>
namespace gtsam {
@ -92,6 +93,34 @@ GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
return {conditionals_, wrap};
}
/* *******************************************************************************/
GaussianBayesNetTree GaussianMixture::add(
const GaussianBayesNetTree &sum) const {
using Y = GaussianBayesNet;
auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1;
if (graph2.size() == 0) {
return GaussianBayesNet();
}
result.push_back(graph2);
return result;
};
const auto tree = asGaussianBayesNetTree();
return sum.empty() ? tree : sum.apply(tree, add);
}
/* *******************************************************************************/
GaussianBayesNetTree GaussianMixture::asGaussianBayesNetTree() const {
auto wrap = [](const GaussianConditional::shared_ptr &gc) {
if (gc) {
return GaussianBayesNet{gc};
} else {
return GaussianBayesNet();
}
};
return {conditionals_, wrap};
}
/* *******************************************************************************/
size_t GaussianMixture::nrComponents() const {
size_t total = 0;
@ -332,10 +361,32 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
/* *******************************************************************************/
double GaussianMixture::error(const HybridValues &values) const {
// Check if discrete keys in discrete assignment are
// present in the GaussianMixture
KeyVector dKeys = this->discreteKeys_.indices();
bool valid_assignment = false;
for (auto &&kv : values.discrete()) {
if (std::find(dKeys.begin(), dKeys.end(), kv.first) != dKeys.end()) {
valid_assignment = true;
break;
}
}
// The discrete assignment is not valid so we return 0.0 erorr.
if (!valid_assignment) {
return 0.0;
}
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(values.discrete());
if (conditional) {
return conditional->error(values.continuous()) + //
logConstant_ - conditional->logNormalizationConstant();
} else {
// If not valid, pointer, it means this conditional was pruned,
// so we return maximum error.
return std::numeric_limits<double>::max();
}
}
/* *******************************************************************************/

View File

@ -71,6 +71,12 @@ class GTSAM_EXPORT GaussianMixture
*/
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
/**
* @brief Convert a DecisionTree of conditionals into
* a DT of Gaussian Bayes nets.
*/
GaussianBayesNetTree asGaussianBayesNetTree() const;
/**
* @brief Helper function to get the pruner functor.
*
@ -248,6 +254,15 @@ class GTSAM_EXPORT GaussianMixture
* @return GaussianFactorGraphTree
*/
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
/**
* @brief Merge the Gaussian Bayes Nets in `this` and `sum` while
* maintaining the decision tree structure.
*
* @param sum Decision Tree of Gaussian Bayes Nets
* @return GaussianBayesNetTree
*/
GaussianBayesNetTree add(const GaussianBayesNetTree &sum) const;
/// @}
private: