helper methods in GaussianMixture for model selection
parent
b20d33d79e
commit
3a89653e91
|
|
@ -24,6 +24,7 @@
|
||||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
#include <gtsam/inference/Conditional-inst.h>
|
#include <gtsam/inference/Conditional-inst.h>
|
||||||
|
#include <gtsam/linear/GaussianBayesNet.h>
|
||||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
@ -92,6 +93,34 @@ GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
|
||||||
return {conditionals_, wrap};
|
return {conditionals_, wrap};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
|
GaussianBayesNetTree GaussianMixture::add(
|
||||||
|
const GaussianBayesNetTree &sum) const {
|
||||||
|
using Y = GaussianBayesNet;
|
||||||
|
auto add = [](const Y &graph1, const Y &graph2) {
|
||||||
|
auto result = graph1;
|
||||||
|
if (graph2.size() == 0) {
|
||||||
|
return GaussianBayesNet();
|
||||||
|
}
|
||||||
|
result.push_back(graph2);
|
||||||
|
return result;
|
||||||
|
};
|
||||||
|
const auto tree = asGaussianBayesNetTree();
|
||||||
|
return sum.empty() ? tree : sum.apply(tree, add);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
|
GaussianBayesNetTree GaussianMixture::asGaussianBayesNetTree() const {
|
||||||
|
auto wrap = [](const GaussianConditional::shared_ptr &gc) {
|
||||||
|
if (gc) {
|
||||||
|
return GaussianBayesNet{gc};
|
||||||
|
} else {
|
||||||
|
return GaussianBayesNet();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
return {conditionals_, wrap};
|
||||||
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
size_t GaussianMixture::nrComponents() const {
|
size_t GaussianMixture::nrComponents() const {
|
||||||
size_t total = 0;
|
size_t total = 0;
|
||||||
|
|
@ -332,10 +361,32 @@ AlgebraicDecisionTree<Key> GaussianMixture::error(
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
double GaussianMixture::error(const HybridValues &values) const {
|
double GaussianMixture::error(const HybridValues &values) const {
|
||||||
|
// Check if discrete keys in discrete assignment are
|
||||||
|
// present in the GaussianMixture
|
||||||
|
KeyVector dKeys = this->discreteKeys_.indices();
|
||||||
|
bool valid_assignment = false;
|
||||||
|
for (auto &&kv : values.discrete()) {
|
||||||
|
if (std::find(dKeys.begin(), dKeys.end(), kv.first) != dKeys.end()) {
|
||||||
|
valid_assignment = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The discrete assignment is not valid so we return 0.0 erorr.
|
||||||
|
if (!valid_assignment) {
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
// Directly index to get the conditional, no need to build the whole tree.
|
// Directly index to get the conditional, no need to build the whole tree.
|
||||||
auto conditional = conditionals_(values.discrete());
|
auto conditional = conditionals_(values.discrete());
|
||||||
|
if (conditional) {
|
||||||
return conditional->error(values.continuous()) + //
|
return conditional->error(values.continuous()) + //
|
||||||
logConstant_ - conditional->logNormalizationConstant();
|
logConstant_ - conditional->logNormalizationConstant();
|
||||||
|
} else {
|
||||||
|
// If not valid, pointer, it means this conditional was pruned,
|
||||||
|
// so we return maximum error.
|
||||||
|
return std::numeric_limits<double>::max();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
|
|
||||||
|
|
@ -71,6 +71,12 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
*/
|
*/
|
||||||
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
|
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Convert a DecisionTree of conditionals into
|
||||||
|
* a DT of Gaussian Bayes nets.
|
||||||
|
*/
|
||||||
|
GaussianBayesNetTree asGaussianBayesNetTree() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Helper function to get the pruner functor.
|
* @brief Helper function to get the pruner functor.
|
||||||
*
|
*
|
||||||
|
|
@ -248,6 +254,15 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
* @return GaussianFactorGraphTree
|
* @return GaussianFactorGraphTree
|
||||||
*/
|
*/
|
||||||
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
|
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Merge the Gaussian Bayes Nets in `this` and `sum` while
|
||||||
|
* maintaining the decision tree structure.
|
||||||
|
*
|
||||||
|
* @param sum Decision Tree of Gaussian Bayes Nets
|
||||||
|
* @return GaussianBayesNetTree
|
||||||
|
*/
|
||||||
|
GaussianBayesNetTree add(const GaussianBayesNetTree &sum) const;
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue