GraphAndConstant struct for hybrid elimination

release/4.3a0
Varun Agrawal 2023-01-01 06:36:48 -05:00
parent 2e6f477569
commit a23322350c
5 changed files with 73 additions and 36 deletions

View File

@ -53,11 +53,11 @@ GaussianMixture::GaussianMixture(
/* *******************************************************************************/
GaussianMixture::Sum GaussianMixture::add(
const GaussianMixture::Sum &sum) const {
using Y = GaussianFactorGraph;
using Y = GaussianMixtureFactor::GraphAndConstant;
auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1;
result.push_back(graph2);
return result;
auto result = graph1.graph;
result.push_back(graph2.graph);
return Y(result, graph1.constant + graph2.constant);
};
const Sum tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add);
@ -65,10 +65,15 @@ GaussianMixture::Sum GaussianMixture::add(
/* *******************************************************************************/
GaussianMixture::Sum GaussianMixture::asGaussianFactorGraphTree() const {
auto lambda = [](const GaussianFactor::shared_ptr &factor) {
auto lambda = [](const GaussianConditional::shared_ptr &conditional) {
GaussianFactorGraph result;
result.push_back(factor);
return result;
result.push_back(conditional);
if (conditional) {
return GaussianMixtureFactor::GraphAndConstant(
result, conditional->logNormalizationConstant());
} else {
return GaussianMixtureFactor::GraphAndConstant(result, 0.0);
}
};
return {conditionals_, lambda};
}

View File

@ -23,13 +23,13 @@
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/inference/Conditional.h>
#include <gtsam/linear/GaussianConditional.h>
namespace gtsam {
class GaussianMixtureFactor;
class HybridValues;
/**
@ -60,7 +60,7 @@ class GTSAM_EXPORT GaussianMixture
using BaseConditional = Conditional<HybridFactor, GaussianMixture>;
/// Alias for DecisionTree of GaussianFactorGraphs
using Sum = DecisionTree<Key, GaussianFactorGraph>;
using Sum = DecisionTree<Key, GaussianMixtureFactor::GraphAndConstant>;
/// typedef for Decision Tree of Gaussian Conditionals
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;

View File

@ -90,11 +90,11 @@ const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() const {
/* *******************************************************************************/
GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
const GaussianMixtureFactor::Sum &sum) const {
using Y = GaussianFactorGraph;
using Y = GaussianMixtureFactor::GraphAndConstant;
auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1;
result.push_back(graph2);
return result;
auto result = graph1.graph;
result.push_back(graph2.graph);
return Y(result, graph1.constant + graph2.constant);
};
const Sum tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add);
@ -106,7 +106,7 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
auto wrap = [](const FactorAndConstant &factor_z) {
GaussianFactorGraph result;
result.push_back(factor_z.factor);
return result;
return GaussianMixtureFactor::GraphAndConstant(result, factor_z.constant);
};
return {factors_, wrap};
}

View File

@ -25,10 +25,10 @@
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h>
namespace gtsam {
class GaussianFactorGraph;
class HybridValues;
class DiscreteValues;
class VectorValues;
@ -50,7 +50,6 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
using This = GaussianMixtureFactor;
using shared_ptr = boost::shared_ptr<This>;
using Sum = DecisionTree<Key, GaussianFactorGraph>;
using sharedFactor = boost::shared_ptr<GaussianFactor>;
/// Gaussian factor and log of normalizing constant.
@ -60,8 +59,9 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
// Return error with constant correction.
double error(const VectorValues &values) const {
// Note minus sign: constant is log of normalization constant for probabilities.
// Errors is the negative log-likelihood, hence we subtract the constant here.
// Note: constant is log of normalization constant for probabilities.
// Errors is the negative log-likelihood,
// hence we subtract the constant here.
return factor->error(values) - constant;
}
@ -71,6 +71,22 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
}
};
/// Gaussian factor graph and log of normalizing constant.
struct GraphAndConstant {
GaussianFactorGraph graph;
double constant;
GraphAndConstant(const GaussianFactorGraph &graph, double constant)
: graph(graph), constant(constant) {}
// Check pointer equality.
bool operator==(const GraphAndConstant &other) const {
return graph == other.graph && constant == other.constant;
}
};
using Sum = DecisionTree<Key, GraphAndConstant>;
/// typedef for Decision Tree of Gaussian factors and log-constant.
using Factors = DecisionTree<Key, FactorAndConstant>;
using Mixture = DecisionTree<Key, sharedFactor>;

View File

@ -61,18 +61,19 @@ template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
/* ************************************************************************ */
static GaussianMixtureFactor::Sum &addGaussian(
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
using Y = GaussianFactorGraph;
// If the decision tree is not initialized, then initialize it.
if (sum.empty()) {
GaussianFactorGraph result;
result.push_back(factor);
sum = GaussianMixtureFactor::Sum(result);
sum = GaussianMixtureFactor::Sum(
GaussianMixtureFactor::GraphAndConstant(result, 0.0));
} else {
auto add = [&factor](const Y &graph) {
auto result = graph;
auto add = [&factor](
const GaussianMixtureFactor::GraphAndConstant &graph_z) {
auto result = graph_z.graph;
result.push_back(factor);
return result;
return GaussianMixtureFactor::GraphAndConstant(result, graph_z.constant);
};
sum = sum.apply(add);
}
@ -190,31 +191,36 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
discreteSeparatorSet.end());
// sum out frontals, this is the factor 𝜏 on the separator
// Collect all the frontal factors to create Gaussian factor graphs
// indexed on the discrete keys.
GaussianMixtureFactor::Sum sum = sumFrontals(factors);
// If a tree leaf contains nullptr,
// convert that leaf to an empty GaussianFactorGraph.
// Needed since the DecisionTree will otherwise create
// a GFG with a single (null) factor.
auto emptyGaussian = [](const GaussianFactorGraph &gfg) {
bool hasNull =
std::any_of(gfg.begin(), gfg.end(),
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
auto emptyGaussian =
[](const GaussianMixtureFactor::GraphAndConstant &graph_z) {
bool hasNull = std::any_of(
graph_z.graph.begin(), graph_z.graph.end(),
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
return hasNull ? GaussianFactorGraph() : gfg;
};
return hasNull ? GaussianMixtureFactor::GraphAndConstant(
GaussianFactorGraph(), 0.0)
: graph_z;
};
sum = GaussianMixtureFactor::Sum(sum, emptyGaussian);
using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>,
GaussianMixtureFactor::FactorAndConstant>;
KeyVector keysOfEliminated; // Not the ordering
KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)?
KeyVector keysOfSeparator;
// This is the elimination method on the leaf nodes
auto eliminate = [&](const GaussianFactorGraph &graph) -> EliminationPair {
if (graph.empty()) {
auto eliminate = [&](const GaussianMixtureFactor::GraphAndConstant &graph_z)
-> EliminationPair {
if (graph_z.graph.empty()) {
return {nullptr, {nullptr, 0.0}};
}
@ -224,7 +230,8 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
std::pair<boost::shared_ptr<GaussianConditional>,
boost::shared_ptr<GaussianFactor>>
conditional_factor = EliminatePreferCholesky(graph, frontalKeys);
conditional_factor =
EliminatePreferCholesky(graph_z.graph, frontalKeys);
// Initialize the keysOfEliminated to be the keys of the
// eliminated GaussianConditional
@ -235,8 +242,10 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
gttoc_(hybrid_eliminate);
#endif
// TODO(Varun) The normalizing constant has to be computed correctly
return {conditional_factor.first, {conditional_factor.second, 0.0}};
GaussianConditional::shared_ptr conditional = conditional_factor.first;
// Get the log of the log normalization constant inverse.
double logZ = -conditional->logNormalizationConstant() + graph_z.constant;
return {conditional, {conditional_factor.second, logZ}};
};
// Perform elimination!
@ -270,6 +279,13 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
}
};
DecisionTree<Key, double> fdt(separatorFactors, factorProb);
// Normalize the values of decision tree to be valid probabilities
double sum = 0.0;
auto visitor = [&](double y) { sum += y; };
fdt.visit(visitor);
// fdt = DecisionTree<Key, double>(fdt,
// [sum](const double &x) { return x / sum;
// });
auto discreteFactor =
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);