improve GaussianMixture by checking for invalid conditionals and adding 2 new methods
parent 2007ef53de
commit 598edfacce

@@ -24,6 +24,7 @@
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianBayesNet.h>
#include <gtsam/linear/GaussianFactorGraph.h>

namespace gtsam {

@@ -92,6 +93,35 @@ GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
  return {conditionals_, wrap};
}

/* *******************************************************************************/
GaussianBayesNetTree GaussianMixture::add(
    const GaussianBayesNetTree &sum) const {
  using Y = GaussianBayesNet;
  auto add = [](const Y &graph1, const Y &graph2) {
    auto result = graph1;
    if (graph2.size() == 0) {
      return GaussianBayesNet();
    }
    result.push_back(graph2);
    return result;
  };
  const auto tree = asGaussianBayesNetTree();
  return sum.empty() ? tree : sum.apply(tree, add);
}
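
For intuition, the leaf-wise merge above is just DecisionTree::apply with an append operation. Below is a minimal standalone sketch of that mechanic, using int leaves in place of GaussianBayesNet and a made-up discrete key m1; it is only an illustration of how matching leaves are combined, not code from this commit.

#include <gtsam/discrete/Assignment.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/inference/Key.h>

using namespace gtsam;

int main() {
  const Key m1 = 1;  // hypothetical discrete key
  // Two trees that both split on m1, with int leaves standing in for
  // the GaussianBayesNet leaves used by GaussianMixture::add.
  DecisionTree<Key, int> left(m1, 10, 20);
  DecisionTree<Key, int> right(m1, 1, 2);
  // apply() visits matching leaves and combines them, just as add()
  // appends one Bayes net onto another per discrete assignment.
  auto merged = left.apply(right, [](int a, int b) { return a + b; });
  Assignment<Key> choice;
  choice[m1] = 1;
  return merged(choice) == 22 ? 0 : 1;  // leaf for m1 = 1 is 20 + 2
}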

/* *******************************************************************************/
GaussianBayesNetTree GaussianMixture::asGaussianBayesNetTree() const {
  auto wrap = [](const GaussianConditional::shared_ptr &gc) {
    if (gc) {
      return GaussianBayesNet{gc};
    } else {
      return GaussianBayesNet();
    }
  };
  return {conditionals_, wrap};
}

/* *******************************************************************************/
size_t GaussianMixture::nrComponents() const {
  size_t total = 0;

@@ -318,8 +348,15 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
 AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
     const VectorValues &continuousValues) const {
   auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
-    return conditional->error(continuousValues) + //
-           logConstant_ - conditional->logNormalizationConstant();
+    // Check if the pointer is valid.
+    if (conditional) {
+      return conditional->error(continuousValues) + //
+             logConstant_ - conditional->logNormalizationConstant();
+    } else {
+      // If the pointer is not valid, this conditional was pruned,
+      // so we return the maximum error.
+      return std::numeric_limits<double>::max();
+    }
   };
   DecisionTree<Key, double> error_tree(conditionals_, errorFunc);
   return error_tree;
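
The same convert-with-a-functor pattern used for error_tree above can be exercised in isolation: a tree of (possibly null) conditionals is mapped leaf-by-leaf to doubles, with a null leaf playing the role of a pruned branch. A small sketch under that assumption, with hypothetical keys x0 and m1 and a throwaway unit-noise conditional; this is illustrative only.

#include <limits>

#include <gtsam/discrete/Assignment.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/linear/GaussianConditional.h>

using namespace gtsam;

int main() {
  const Key x0 = 0, m1 = 1;  // hypothetical continuous / discrete keys
  // One real conditional and one null leaf, the latter mimicking a branch
  // removed by pruning.
  GaussianConditional::shared_ptr gc(
      new GaussianConditional(x0, Vector1(0.0), I_1x1));
  DecisionTree<Key, GaussianConditional::shared_ptr> conditionals(
      m1, gc, GaussianConditional::shared_ptr());
  // Convert to a tree of errors: pruned branches get the maximum error,
  // mirroring the convention adopted in errorTree().
  DecisionTree<Key, double> errors(
      conditionals, [](const GaussianConditional::shared_ptr &c) {
        return c ? 0.0 : std::numeric_limits<double>::max();
      });
  Assignment<Key> choice;
  choice[m1] = 1;
  return errors(choice) == std::numeric_limits<double>::max() ? 0 : 1;
}
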
@@ -327,10 +364,32 @@ AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
 
 /* *******************************************************************************/
 double GaussianMixture::error(const HybridValues &values) const {
+  // Check if discrete keys in the discrete assignment are
+  // present in the GaussianMixture.
+  KeyVector dKeys = this->discreteKeys_.indices();
+  bool valid_assignment = false;
+  for (auto &&kv : values.discrete()) {
+    if (std::find(dKeys.begin(), dKeys.end(), kv.first) != dKeys.end()) {
+      valid_assignment = true;
+      break;
+    }
+  }
+
+  // The discrete assignment is not valid, so we return 0.0 error.
+  if (!valid_assignment) {
+    return 0.0;
+  }
+
   // Directly index to get the conditional, no need to build the whole tree.
   auto conditional = conditionals_(values.discrete());
-  return conditional->error(values.continuous()) + //
-         logConstant_ - conditional->logNormalizationConstant();
+  if (conditional) {
+    return conditional->error(values.continuous()) + //
+           logConstant_ - conditional->logNormalizationConstant();
+  } else {
+    // If the pointer is not valid, this conditional was pruned,
+    // so we return the maximum error.
+    return std::numeric_limits<double>::max();
+  }
 }
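
The early return above hinges on whether the assignment mentions any of the mixture's discrete keys. Below is that membership test in isolation; mentionsAnyMixtureKey and the keys m1, m2 are hypothetical names for this sketch and not GTSAM API, but the std::find pattern matches the code above.

#include <algorithm>

#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/inference/Key.h>

using namespace gtsam;

// Returns true if the assignment fixes at least one of the mixture's
// discrete keys; error() returns 0.0 when this is false.
bool mentionsAnyMixtureKey(const KeyVector &mixtureDiscreteKeys,
                           const DiscreteValues &assignment) {
  for (const auto &kv : assignment) {
    if (std::find(mixtureDiscreteKeys.begin(), mixtureDiscreteKeys.end(),
                  kv.first) != mixtureDiscreteKeys.end()) {
      return true;
    }
  }
  return false;
}

int main() {
  const Key m1 = 1, m2 = 2;  // hypothetical discrete keys
  DiscreteValues assignment;
  assignment[m1] = 0;  // only m1 is assigned
  return mentionsAnyMixtureKey(KeyVector{m1, m2}, assignment) ? 0 : 1;
}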
 
 /* *******************************************************************************/

@@ -72,6 +72,12 @@ class GTSAM_EXPORT GaussianMixture
   */
  GaussianFactorGraphTree asGaussianFactorGraphTree() const;

  /**
   * @brief Convert a DecisionTree of conditionals into
   * a DecisionTree of Gaussian Bayes nets.
   */
  GaussianBayesNetTree asGaussianBayesNetTree() const;

  /**
   * @brief Helper function to get the pruner functor.
   *

@@ -250,6 +256,15 @@ class GTSAM_EXPORT GaussianMixture
   * @return GaussianFactorGraphTree
   */
  GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;

  /**
   * @brief Merge the Gaussian Bayes Nets in `this` and `sum` while
   * maintaining the decision tree structure.
   *
   * @param sum Decision Tree of Gaussian Bayes Nets
   * @return GaussianBayesNetTree
   */
  GaussianBayesNetTree add(const GaussianBayesNetTree &sum) const;
  /// @}

 private:

@@ -13,6 +13,7 @@
 * @file HybridFactor.h
 * @date Mar 11, 2022
 * @author Fan Jiang
 * @author Varun Agrawal
 */

#pragma once

@@ -33,6 +34,8 @@ class HybridValues;

/// Alias for DecisionTree of GaussianFactorGraphs
using GaussianFactorGraphTree = DecisionTree<Key, GaussianFactorGraph>;
/// Alias for DecisionTree of GaussianBayesNets
using GaussianBayesNetTree = DecisionTree<Key, GaussianBayesNet>;

KeyVector CollectKeys(const KeyVector &continuousKeys,
                      const DiscreteKeys &discreteKeys);
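
To make the new alias concrete, here is a short sketch that builds a two-leaf GaussianBayesNetTree by hand and indexes it with a discrete assignment; the keys x0 and m1 and the unit-noise conditionals are made up for illustration, not taken from this commit.

#include <gtsam/discrete/Assignment.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/linear/GaussianBayesNet.h>
#include <gtsam/linear/GaussianConditional.h>

using namespace gtsam;

int main() {
  const Key x0 = 0, m1 = 1;  // hypothetical continuous / discrete keys
  // One single-conditional Bayes net per value of the discrete key m1.
  GaussianConditional::shared_ptr gc0(
      new GaussianConditional(x0, Vector1(0.0), I_1x1));
  GaussianConditional::shared_ptr gc1(
      new GaussianConditional(x0, Vector1(1.0), I_1x1));
  GaussianBayesNetTree tree(m1, GaussianBayesNet{gc0}, GaussianBayesNet{gc1});
  // Select the leaf Bayes net for m1 = 1.
  Assignment<Key> choice;
  choice[m1] = 1;
  return tree(choice).size() == 1 ? 0 : 1;
}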