helper methods in GaussianMixture for model selection
parent
b20d33d79e
commit
3a89653e91
|
|
@ -24,6 +24,7 @@
|
|||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
#include <gtsam/inference/Conditional-inst.h>
|
||||
#include <gtsam/linear/GaussianBayesNet.h>
|
||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
|
@ -92,6 +93,34 @@ GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
|
|||
return {conditionals_, wrap};
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
GaussianBayesNetTree GaussianMixture::add(
|
||||
const GaussianBayesNetTree &sum) const {
|
||||
using Y = GaussianBayesNet;
|
||||
auto add = [](const Y &graph1, const Y &graph2) {
|
||||
auto result = graph1;
|
||||
if (graph2.size() == 0) {
|
||||
return GaussianBayesNet();
|
||||
}
|
||||
result.push_back(graph2);
|
||||
return result;
|
||||
};
|
||||
const auto tree = asGaussianBayesNetTree();
|
||||
return sum.empty() ? tree : sum.apply(tree, add);
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
GaussianBayesNetTree GaussianMixture::asGaussianBayesNetTree() const {
|
||||
auto wrap = [](const GaussianConditional::shared_ptr &gc) {
|
||||
if (gc) {
|
||||
return GaussianBayesNet{gc};
|
||||
} else {
|
||||
return GaussianBayesNet();
|
||||
}
|
||||
};
|
||||
return {conditionals_, wrap};
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
size_t GaussianMixture::nrComponents() const {
|
||||
size_t total = 0;
|
||||
|
|
@ -332,10 +361,32 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
|
|||
|
||||
/* *******************************************************************************/
|
||||
double GaussianMixture::error(const HybridValues &values) const {
|
||||
// Check if discrete keys in discrete assignment are
|
||||
// present in the GaussianMixture
|
||||
KeyVector dKeys = this->discreteKeys_.indices();
|
||||
bool valid_assignment = false;
|
||||
for (auto &&kv : values.discrete()) {
|
||||
if (std::find(dKeys.begin(), dKeys.end(), kv.first) != dKeys.end()) {
|
||||
valid_assignment = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// The discrete assignment is not valid so we return 0.0 erorr.
|
||||
if (!valid_assignment) {
|
||||
return 0.0;
|
||||
}
|
||||
|
||||
// Directly index to get the conditional, no need to build the whole tree.
|
||||
auto conditional = conditionals_(values.discrete());
|
||||
return conditional->error(values.continuous()) + //
|
||||
logConstant_ - conditional->logNormalizationConstant();
|
||||
if (conditional) {
|
||||
return conditional->error(values.continuous()) + //
|
||||
logConstant_ - conditional->logNormalizationConstant();
|
||||
} else {
|
||||
// If not valid, pointer, it means this conditional was pruned,
|
||||
// so we return maximum error.
|
||||
return std::numeric_limits<double>::max();
|
||||
}
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
|
|
|||
|
|
@ -71,6 +71,12 @@ class GTSAM_EXPORT GaussianMixture
|
|||
*/
|
||||
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
|
||||
|
||||
/**
|
||||
* @brief Convert a DecisionTree of conditionals into
|
||||
* a DT of Gaussian Bayes nets.
|
||||
*/
|
||||
GaussianBayesNetTree asGaussianBayesNetTree() const;
|
||||
|
||||
/**
|
||||
* @brief Helper function to get the pruner functor.
|
||||
*
|
||||
|
|
@ -248,6 +254,15 @@ class GTSAM_EXPORT GaussianMixture
|
|||
* @return GaussianFactorGraphTree
|
||||
*/
|
||||
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
|
||||
|
||||
/**
|
||||
* @brief Merge the Gaussian Bayes Nets in `this` and `sum` while
|
||||
* maintaining the decision tree structure.
|
||||
*
|
||||
* @param sum Decision Tree of Gaussian Bayes Nets
|
||||
* @return GaussianBayesNetTree
|
||||
*/
|
||||
GaussianBayesNetTree add(const GaussianBayesNetTree &sum) const;
|
||||
/// @}
|
||||
|
||||
private:
|
||||
|
|
|
|||
Loading…
Reference in New Issue