GraphAndConstant struct for hybrid elimination
parent
2e6f477569
commit
a23322350c
|
|
@ -53,11 +53,11 @@ GaussianMixture::GaussianMixture(
|
|||
/* *******************************************************************************/
|
||||
GaussianMixture::Sum GaussianMixture::add(
|
||||
const GaussianMixture::Sum &sum) const {
|
||||
using Y = GaussianFactorGraph;
|
||||
using Y = GaussianMixtureFactor::GraphAndConstant;
|
||||
auto add = [](const Y &graph1, const Y &graph2) {
|
||||
auto result = graph1;
|
||||
result.push_back(graph2);
|
||||
return result;
|
||||
auto result = graph1.graph;
|
||||
result.push_back(graph2.graph);
|
||||
return Y(result, graph1.constant + graph2.constant);
|
||||
};
|
||||
const Sum tree = asGaussianFactorGraphTree();
|
||||
return sum.empty() ? tree : sum.apply(tree, add);
|
||||
|
|
@ -65,10 +65,15 @@ GaussianMixture::Sum GaussianMixture::add(
|
|||
|
||||
/* *******************************************************************************/
|
||||
GaussianMixture::Sum GaussianMixture::asGaussianFactorGraphTree() const {
|
||||
auto lambda = [](const GaussianFactor::shared_ptr &factor) {
|
||||
auto lambda = [](const GaussianConditional::shared_ptr &conditional) {
|
||||
GaussianFactorGraph result;
|
||||
result.push_back(factor);
|
||||
return result;
|
||||
result.push_back(conditional);
|
||||
if (conditional) {
|
||||
return GaussianMixtureFactor::GraphAndConstant(
|
||||
result, conditional->logNormalizationConstant());
|
||||
} else {
|
||||
return GaussianMixtureFactor::GraphAndConstant(result, 0.0);
|
||||
}
|
||||
};
|
||||
return {conditionals_, lambda};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -23,13 +23,13 @@
|
|||
#include <gtsam/discrete/DecisionTree.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
#include <gtsam/inference/Conditional.h>
|
||||
#include <gtsam/linear/GaussianConditional.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
class GaussianMixtureFactor;
|
||||
class HybridValues;
|
||||
|
||||
/**
|
||||
|
|
@ -60,7 +60,7 @@ class GTSAM_EXPORT GaussianMixture
|
|||
using BaseConditional = Conditional<HybridFactor, GaussianMixture>;
|
||||
|
||||
/// Alias for DecisionTree of GaussianFactorGraphs
|
||||
using Sum = DecisionTree<Key, GaussianFactorGraph>;
|
||||
using Sum = DecisionTree<Key, GaussianMixtureFactor::GraphAndConstant>;
|
||||
|
||||
/// typedef for Decision Tree of Gaussian Conditionals
|
||||
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
|
||||
|
|
|
|||
|
|
@ -90,11 +90,11 @@ const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() const {
|
|||
/* *******************************************************************************/
|
||||
GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
|
||||
const GaussianMixtureFactor::Sum &sum) const {
|
||||
using Y = GaussianFactorGraph;
|
||||
using Y = GaussianMixtureFactor::GraphAndConstant;
|
||||
auto add = [](const Y &graph1, const Y &graph2) {
|
||||
auto result = graph1;
|
||||
result.push_back(graph2);
|
||||
return result;
|
||||
auto result = graph1.graph;
|
||||
result.push_back(graph2.graph);
|
||||
return Y(result, graph1.constant + graph2.constant);
|
||||
};
|
||||
const Sum tree = asGaussianFactorGraphTree();
|
||||
return sum.empty() ? tree : sum.apply(tree, add);
|
||||
|
|
@ -106,7 +106,7 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
|
|||
auto wrap = [](const FactorAndConstant &factor_z) {
|
||||
GaussianFactorGraph result;
|
||||
result.push_back(factor_z.factor);
|
||||
return result;
|
||||
return GaussianMixtureFactor::GraphAndConstant(result, factor_z.constant);
|
||||
};
|
||||
return {factors_, wrap};
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,10 +25,10 @@
|
|||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
#include <gtsam/linear/GaussianFactor.h>
|
||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
class GaussianFactorGraph;
|
||||
class HybridValues;
|
||||
class DiscreteValues;
|
||||
class VectorValues;
|
||||
|
|
@ -50,7 +50,6 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
using This = GaussianMixtureFactor;
|
||||
using shared_ptr = boost::shared_ptr<This>;
|
||||
|
||||
using Sum = DecisionTree<Key, GaussianFactorGraph>;
|
||||
using sharedFactor = boost::shared_ptr<GaussianFactor>;
|
||||
|
||||
/// Gaussian factor and log of normalizing constant.
|
||||
|
|
@ -60,8 +59,9 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
|
||||
// Return error with constant correction.
|
||||
double error(const VectorValues &values) const {
|
||||
// Note minus sign: constant is log of normalization constant for probabilities.
|
||||
// Errors is the negative log-likelihood, hence we subtract the constant here.
|
||||
// Note: constant is log of normalization constant for probabilities.
|
||||
// Errors is the negative log-likelihood,
|
||||
// hence we subtract the constant here.
|
||||
return factor->error(values) - constant;
|
||||
}
|
||||
|
||||
|
|
@ -71,6 +71,22 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
|||
}
|
||||
};
|
||||
|
||||
/// Gaussian factor graph and log of normalizing constant.
|
||||
struct GraphAndConstant {
|
||||
GaussianFactorGraph graph;
|
||||
double constant;
|
||||
|
||||
GraphAndConstant(const GaussianFactorGraph &graph, double constant)
|
||||
: graph(graph), constant(constant) {}
|
||||
|
||||
// Check pointer equality.
|
||||
bool operator==(const GraphAndConstant &other) const {
|
||||
return graph == other.graph && constant == other.constant;
|
||||
}
|
||||
};
|
||||
|
||||
using Sum = DecisionTree<Key, GraphAndConstant>;
|
||||
|
||||
/// typedef for Decision Tree of Gaussian factors and log-constant.
|
||||
using Factors = DecisionTree<Key, FactorAndConstant>;
|
||||
using Mixture = DecisionTree<Key, sharedFactor>;
|
||||
|
|
|
|||
|
|
@ -61,18 +61,19 @@ template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
|
|||
/* ************************************************************************ */
|
||||
static GaussianMixtureFactor::Sum &addGaussian(
|
||||
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
|
||||
using Y = GaussianFactorGraph;
|
||||
// If the decision tree is not initialized, then initialize it.
|
||||
if (sum.empty()) {
|
||||
GaussianFactorGraph result;
|
||||
result.push_back(factor);
|
||||
sum = GaussianMixtureFactor::Sum(result);
|
||||
sum = GaussianMixtureFactor::Sum(
|
||||
GaussianMixtureFactor::GraphAndConstant(result, 0.0));
|
||||
|
||||
} else {
|
||||
auto add = [&factor](const Y &graph) {
|
||||
auto result = graph;
|
||||
auto add = [&factor](
|
||||
const GaussianMixtureFactor::GraphAndConstant &graph_z) {
|
||||
auto result = graph_z.graph;
|
||||
result.push_back(factor);
|
||||
return result;
|
||||
return GaussianMixtureFactor::GraphAndConstant(result, graph_z.constant);
|
||||
};
|
||||
sum = sum.apply(add);
|
||||
}
|
||||
|
|
@ -190,31 +191,36 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
|
||||
discreteSeparatorSet.end());
|
||||
|
||||
// sum out frontals, this is the factor 𝜏 on the separator
|
||||
// Collect all the frontal factors to create Gaussian factor graphs
|
||||
// indexed on the discrete keys.
|
||||
GaussianMixtureFactor::Sum sum = sumFrontals(factors);
|
||||
|
||||
// If a tree leaf contains nullptr,
|
||||
// convert that leaf to an empty GaussianFactorGraph.
|
||||
// Needed since the DecisionTree will otherwise create
|
||||
// a GFG with a single (null) factor.
|
||||
auto emptyGaussian = [](const GaussianFactorGraph &gfg) {
|
||||
bool hasNull =
|
||||
std::any_of(gfg.begin(), gfg.end(),
|
||||
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
|
||||
auto emptyGaussian =
|
||||
[](const GaussianMixtureFactor::GraphAndConstant &graph_z) {
|
||||
bool hasNull = std::any_of(
|
||||
graph_z.graph.begin(), graph_z.graph.end(),
|
||||
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
|
||||
|
||||
return hasNull ? GaussianFactorGraph() : gfg;
|
||||
};
|
||||
return hasNull ? GaussianMixtureFactor::GraphAndConstant(
|
||||
GaussianFactorGraph(), 0.0)
|
||||
: graph_z;
|
||||
};
|
||||
sum = GaussianMixtureFactor::Sum(sum, emptyGaussian);
|
||||
|
||||
using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>,
|
||||
GaussianMixtureFactor::FactorAndConstant>;
|
||||
|
||||
KeyVector keysOfEliminated; // Not the ordering
|
||||
KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)?
|
||||
KeyVector keysOfSeparator;
|
||||
|
||||
// This is the elimination method on the leaf nodes
|
||||
auto eliminate = [&](const GaussianFactorGraph &graph) -> EliminationPair {
|
||||
if (graph.empty()) {
|
||||
auto eliminate = [&](const GaussianMixtureFactor::GraphAndConstant &graph_z)
|
||||
-> EliminationPair {
|
||||
if (graph_z.graph.empty()) {
|
||||
return {nullptr, {nullptr, 0.0}};
|
||||
}
|
||||
|
||||
|
|
@ -224,7 +230,8 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
|
||||
std::pair<boost::shared_ptr<GaussianConditional>,
|
||||
boost::shared_ptr<GaussianFactor>>
|
||||
conditional_factor = EliminatePreferCholesky(graph, frontalKeys);
|
||||
conditional_factor =
|
||||
EliminatePreferCholesky(graph_z.graph, frontalKeys);
|
||||
|
||||
// Initialize the keysOfEliminated to be the keys of the
|
||||
// eliminated GaussianConditional
|
||||
|
|
@ -235,8 +242,10 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
gttoc_(hybrid_eliminate);
|
||||
#endif
|
||||
|
||||
// TODO(Varun) The normalizing constant has to be computed correctly
|
||||
return {conditional_factor.first, {conditional_factor.second, 0.0}};
|
||||
GaussianConditional::shared_ptr conditional = conditional_factor.first;
|
||||
// Get the log of the log normalization constant inverse.
|
||||
double logZ = -conditional->logNormalizationConstant() + graph_z.constant;
|
||||
return {conditional, {conditional_factor.second, logZ}};
|
||||
};
|
||||
|
||||
// Perform elimination!
|
||||
|
|
@ -270,6 +279,13 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
|||
}
|
||||
};
|
||||
DecisionTree<Key, double> fdt(separatorFactors, factorProb);
|
||||
// Normalize the values of decision tree to be valid probabilities
|
||||
double sum = 0.0;
|
||||
auto visitor = [&](double y) { sum += y; };
|
||||
fdt.visit(visitor);
|
||||
// fdt = DecisionTree<Key, double>(fdt,
|
||||
// [sum](const double &x) { return x / sum;
|
||||
// });
|
||||
|
||||
auto discreteFactor =
|
||||
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
|
||||
|
|
|
|||
Loading…
Reference in New Issue