improve GaussianMixture by checking for invalid conditionals and adding 2 new methods
parent
2007ef53de
commit
598edfacce
|
|
@ -24,6 +24,7 @@
|
||||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||||
#include <gtsam/hybrid/HybridValues.h>
|
#include <gtsam/hybrid/HybridValues.h>
|
||||||
#include <gtsam/inference/Conditional-inst.h>
|
#include <gtsam/inference/Conditional-inst.h>
|
||||||
|
#include <gtsam/linear/GaussianBayesNet.h>
|
||||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
@ -92,6 +93,35 @@ GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
|
||||||
return {conditionals_, wrap};
|
return {conditionals_, wrap};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
*******************************************************************************/
|
||||||
|
GaussianBayesNetTree GaussianMixture::add(
|
||||||
|
const GaussianBayesNetTree &sum) const {
|
||||||
|
using Y = GaussianBayesNet;
|
||||||
|
auto add = [](const Y &graph1, const Y &graph2) {
|
||||||
|
auto result = graph1;
|
||||||
|
if (graph2.size() == 0) {
|
||||||
|
return GaussianBayesNet();
|
||||||
|
}
|
||||||
|
result.push_back(graph2);
|
||||||
|
return result;
|
||||||
|
};
|
||||||
|
const auto tree = asGaussianBayesNetTree();
|
||||||
|
return sum.empty() ? tree : sum.apply(tree, add);
|
||||||
|
}
|
||||||
|
|
||||||
|
/* *******************************************************************************/
|
||||||
|
GaussianBayesNetTree GaussianMixture::asGaussianBayesNetTree() const {
|
||||||
|
auto wrap = [](const GaussianConditional::shared_ptr &gc) {
|
||||||
|
if (gc) {
|
||||||
|
return GaussianBayesNet{gc};
|
||||||
|
} else {
|
||||||
|
return GaussianBayesNet();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
return {conditionals_, wrap};
|
||||||
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
size_t GaussianMixture::nrComponents() const {
|
size_t GaussianMixture::nrComponents() const {
|
||||||
size_t total = 0;
|
size_t total = 0;
|
||||||
|
|
@ -318,8 +348,15 @@ AlgebraicDecisionTree<Key> GaussianMixture::logProbability(
|
||||||
AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
|
AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
|
||||||
const VectorValues &continuousValues) const {
|
const VectorValues &continuousValues) const {
|
||||||
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
|
auto errorFunc = [&](const GaussianConditional::shared_ptr &conditional) {
|
||||||
return conditional->error(continuousValues) + //
|
// Check if valid pointer
|
||||||
logConstant_ - conditional->logNormalizationConstant();
|
if (conditional) {
|
||||||
|
return conditional->error(continuousValues) + //
|
||||||
|
logConstant_ - conditional->logNormalizationConstant();
|
||||||
|
} else {
|
||||||
|
// If not valid, pointer, it means this conditional was pruned,
|
||||||
|
// so we return maximum error.
|
||||||
|
return std::numeric_limits<double>::max();
|
||||||
|
}
|
||||||
};
|
};
|
||||||
DecisionTree<Key, double> error_tree(conditionals_, errorFunc);
|
DecisionTree<Key, double> error_tree(conditionals_, errorFunc);
|
||||||
return error_tree;
|
return error_tree;
|
||||||
|
|
@ -327,10 +364,32 @@ AlgebraicDecisionTree<Key> GaussianMixture::errorTree(
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
double GaussianMixture::error(const HybridValues &values) const {
|
double GaussianMixture::error(const HybridValues &values) const {
|
||||||
|
// Check if discrete keys in discrete assignment are
|
||||||
|
// present in the GaussianMixture
|
||||||
|
KeyVector dKeys = this->discreteKeys_.indices();
|
||||||
|
bool valid_assignment = false;
|
||||||
|
for (auto &&kv : values.discrete()) {
|
||||||
|
if (std::find(dKeys.begin(), dKeys.end(), kv.first) != dKeys.end()) {
|
||||||
|
valid_assignment = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// The discrete assignment is not valid so we return 0.0 erorr.
|
||||||
|
if (!valid_assignment) {
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
// Directly index to get the conditional, no need to build the whole tree.
|
// Directly index to get the conditional, no need to build the whole tree.
|
||||||
auto conditional = conditionals_(values.discrete());
|
auto conditional = conditionals_(values.discrete());
|
||||||
return conditional->error(values.continuous()) + //
|
if (conditional) {
|
||||||
logConstant_ - conditional->logNormalizationConstant();
|
return conditional->error(values.continuous()) + //
|
||||||
|
logConstant_ - conditional->logNormalizationConstant();
|
||||||
|
} else {
|
||||||
|
// If not valid, pointer, it means this conditional was pruned,
|
||||||
|
// so we return maximum error.
|
||||||
|
return std::numeric_limits<double>::max();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
|
|
||||||
|
|
@ -72,6 +72,12 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
*/
|
*/
|
||||||
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
|
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Convert a DecisionTree of conditionals into
|
||||||
|
* a DecisionTree of Gaussian Bayes nets.
|
||||||
|
*/
|
||||||
|
GaussianBayesNetTree asGaussianBayesNetTree() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Helper function to get the pruner functor.
|
* @brief Helper function to get the pruner functor.
|
||||||
*
|
*
|
||||||
|
|
@ -250,6 +256,15 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
* @return GaussianFactorGraphTree
|
* @return GaussianFactorGraphTree
|
||||||
*/
|
*/
|
||||||
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
|
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @brief Merge the Gaussian Bayes Nets in `this` and `sum` while
|
||||||
|
* maintaining the decision tree structure.
|
||||||
|
*
|
||||||
|
* @param sum Decision Tree of Gaussian Bayes Nets
|
||||||
|
* @return GaussianBayesNetTree
|
||||||
|
*/
|
||||||
|
GaussianBayesNetTree add(const GaussianBayesNetTree &sum) const;
|
||||||
/// @}
|
/// @}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@
|
||||||
* @file HybridFactor.h
|
* @file HybridFactor.h
|
||||||
* @date Mar 11, 2022
|
* @date Mar 11, 2022
|
||||||
* @author Fan Jiang
|
* @author Fan Jiang
|
||||||
|
* @author Varun Agrawal
|
||||||
*/
|
*/
|
||||||
|
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
@ -33,6 +34,8 @@ class HybridValues;
|
||||||
|
|
||||||
/// Alias for DecisionTree of GaussianFactorGraphs
|
/// Alias for DecisionTree of GaussianFactorGraphs
|
||||||
using GaussianFactorGraphTree = DecisionTree<Key, GaussianFactorGraph>;
|
using GaussianFactorGraphTree = DecisionTree<Key, GaussianFactorGraph>;
|
||||||
|
/// Alias for DecisionTree of GaussianBayesNets
|
||||||
|
using GaussianBayesNetTree = DecisionTree<Key, GaussianBayesNet>;
|
||||||
|
|
||||||
KeyVector CollectKeys(const KeyVector &continuousKeys,
|
KeyVector CollectKeys(const KeyVector &continuousKeys,
|
||||||
const DiscreteKeys &discreteKeys);
|
const DiscreteKeys &discreteKeys);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue