improve GaussianMixture by checking for invalid conditionals and adding 2 new methods

release/4.3a0
Varun Agrawal 2024-08-20 16:28:28 -04:00
parent 2007ef53de
commit 598edfacce
3 changed files with 81 additions and 4 deletions

View File

@ -24,6 +24,7 @@
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianBayesNet.h>
#include <gtsam/linear/GaussianFactorGraph.h>
namespace gtsam {
@ -92,6 +93,35 @@ GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
return {conditionals_, wrap};
}
/*
*******************************************************************************/
GaussianBayesNetTree GaussianMixture::add(
const GaussianBayesNetTree &sum) const {
using Y = GaussianBayesNet;
auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1;
if (graph2.size() == 0) {
return GaussianBayesNet();
}
result.push_back(graph2);
return result;
};
const auto tree = asGaussianBayesNetTree();
return sum.empty() ? tree : sum.apply(tree, add);
}
/* *******************************************************************************/
GaussianBayesNetTree GaussianMixture::asGaussianBayesNetTree() const {
auto wrap = [](const GaussianConditional::shared_ptr &gc) {
if (gc) {
return GaussianBayesNet{gc};
} else {
return GaussianBayesNet();
}
};
return {conditionals_, wrap};
}
/* *******************************************************************************/
size_t GaussianMixture::nrComponents() const {
size_t total = 0;
@ -318,8 +348,15 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
const VectorValues &continuousValues) const {
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
return conditional->error(continuousValues) + //
logConstant_ - conditional->logNormalizationConstant();
// Check if valid pointer
if (conditional) {
return conditional->error(continuousValues) + //
logConstant_ - conditional->logNormalizationConstant();
} else {
// If not valid, pointer, it means this conditional was pruned,
// so we return maximum error.
return std::numeric_limits<double>::max();
}
};
DecisionTree<Key, double> error_tree(conditionals_, errorFunc);
return error_tree;
@ -327,10 +364,32 @@ AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
/* *******************************************************************************/
double GaussianMixture::error(const HybridValues &values) const {
// Check if discrete keys in discrete assignment are
// present in the GaussianMixture
KeyVector dKeys = this->discreteKeys_.indices();
bool valid_assignment = false;
for (auto &&kv : values.discrete()) {
if (std::find(dKeys.begin(), dKeys.end(), kv.first) != dKeys.end()) {
valid_assignment = true;
break;
}
}
// The discrete assignment is not valid so we return 0.0 erorr.
if (!valid_assignment) {
return 0.0;
}
// Directly index to get the conditional, no need to build the whole tree.
auto conditional = conditionals_(values.discrete());
return conditional->error(values.continuous()) + //
logConstant_ - conditional->logNormalizationConstant();
if (conditional) {
return conditional->error(values.continuous()) + //
logConstant_ - conditional->logNormalizationConstant();
} else {
// If not valid, pointer, it means this conditional was pruned,
// so we return maximum error.
return std::numeric_limits<double>::max();
}
}
/* *******************************************************************************/

View File

@ -72,6 +72,12 @@ class GTSAM_EXPORT GaussianMixture
*/
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
/**
* @brief Convert a DecisionTree of conditionals into
* a DecisionTree of Gaussian Bayes nets.
*/
GaussianBayesNetTree asGaussianBayesNetTree() const;
/**
* @brief Helper function to get the pruner functor.
*
@ -250,6 +256,15 @@ class GTSAM_EXPORT GaussianMixture
* @return GaussianFactorGraphTree
*/
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
/**
* @brief Merge the Gaussian Bayes Nets in `this` and `sum` while
* maintaining the decision tree structure.
*
* @param sum Decision Tree of Gaussian Bayes Nets
* @return GaussianBayesNetTree
*/
GaussianBayesNetTree add(const GaussianBayesNetTree &sum) const;
/// @}
private:

View File

@ -13,6 +13,7 @@
* @file HybridFactor.h
* @date Mar 11, 2022
* @author Fan Jiang
* @author Varun Agrawal
*/
#pragma once
@ -33,6 +34,8 @@ class HybridValues;
/// Alias for DecisionTree of GaussianFactorGraphs
using GaussianFactorGraphTree = DecisionTree<Key, GaussianFactorGraph>;
/// Alias for DecisionTree of GaussianBayesNets
using GaussianBayesNetTree = DecisionTree<Key, GaussianBayesNet>;
KeyVector CollectKeys(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys);