GraphAndConstant struct for hybrid elimination
parent
2e6f477569
commit
a23322350c
|
|
@ -53,11 +53,11 @@ GaussianMixture::GaussianMixture(
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianMixture::Sum GaussianMixture::add(
|
GaussianMixture::Sum GaussianMixture::add(
|
||||||
const GaussianMixture::Sum &sum) const {
|
const GaussianMixture::Sum &sum) const {
|
||||||
using Y = GaussianFactorGraph;
|
using Y = GaussianMixtureFactor::GraphAndConstant;
|
||||||
auto add = [](const Y &graph1, const Y &graph2) {
|
auto add = [](const Y &graph1, const Y &graph2) {
|
||||||
auto result = graph1;
|
auto result = graph1.graph;
|
||||||
result.push_back(graph2);
|
result.push_back(graph2.graph);
|
||||||
return result;
|
return Y(result, graph1.constant + graph2.constant);
|
||||||
};
|
};
|
||||||
const Sum tree = asGaussianFactorGraphTree();
|
const Sum tree = asGaussianFactorGraphTree();
|
||||||
return sum.empty() ? tree : sum.apply(tree, add);
|
return sum.empty() ? tree : sum.apply(tree, add);
|
||||||
|
|
@ -65,10 +65,15 @@ GaussianMixture::Sum GaussianMixture::add(
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianMixture::Sum GaussianMixture::asGaussianFactorGraphTree() const {
|
GaussianMixture::Sum GaussianMixture::asGaussianFactorGraphTree() const {
|
||||||
auto lambda = [](const GaussianFactor::shared_ptr &factor) {
|
auto lambda = [](const GaussianConditional::shared_ptr &conditional) {
|
||||||
GaussianFactorGraph result;
|
GaussianFactorGraph result;
|
||||||
result.push_back(factor);
|
result.push_back(conditional);
|
||||||
return result;
|
if (conditional) {
|
||||||
|
return GaussianMixtureFactor::GraphAndConstant(
|
||||||
|
result, conditional->logNormalizationConstant());
|
||||||
|
} else {
|
||||||
|
return GaussianMixtureFactor::GraphAndConstant(result, 0.0);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
return {conditionals_, lambda};
|
return {conditionals_, lambda};
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -23,13 +23,13 @@
|
||||||
#include <gtsam/discrete/DecisionTree.h>
|
#include <gtsam/discrete/DecisionTree.h>
|
||||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
|
#include <gtsam/hybrid/GaussianMixtureFactor.h>
|
||||||
#include <gtsam/hybrid/HybridFactor.h>
|
#include <gtsam/hybrid/HybridFactor.h>
|
||||||
#include <gtsam/inference/Conditional.h>
|
#include <gtsam/inference/Conditional.h>
|
||||||
#include <gtsam/linear/GaussianConditional.h>
|
#include <gtsam/linear/GaussianConditional.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
class GaussianMixtureFactor;
|
|
||||||
class HybridValues;
|
class HybridValues;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -60,7 +60,7 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
using BaseConditional = Conditional<HybridFactor, GaussianMixture>;
|
using BaseConditional = Conditional<HybridFactor, GaussianMixture>;
|
||||||
|
|
||||||
/// Alias for DecisionTree of GaussianFactorGraphs
|
/// Alias for DecisionTree of GaussianFactorGraphs
|
||||||
using Sum = DecisionTree<Key, GaussianFactorGraph>;
|
using Sum = DecisionTree<Key, GaussianMixtureFactor::GraphAndConstant>;
|
||||||
|
|
||||||
/// typedef for Decision Tree of Gaussian Conditionals
|
/// typedef for Decision Tree of Gaussian Conditionals
|
||||||
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
|
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
|
||||||
|
|
|
||||||
|
|
@ -90,11 +90,11 @@ const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() const {
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
|
GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
|
||||||
const GaussianMixtureFactor::Sum &sum) const {
|
const GaussianMixtureFactor::Sum &sum) const {
|
||||||
using Y = GaussianFactorGraph;
|
using Y = GaussianMixtureFactor::GraphAndConstant;
|
||||||
auto add = [](const Y &graph1, const Y &graph2) {
|
auto add = [](const Y &graph1, const Y &graph2) {
|
||||||
auto result = graph1;
|
auto result = graph1.graph;
|
||||||
result.push_back(graph2);
|
result.push_back(graph2.graph);
|
||||||
return result;
|
return Y(result, graph1.constant + graph2.constant);
|
||||||
};
|
};
|
||||||
const Sum tree = asGaussianFactorGraphTree();
|
const Sum tree = asGaussianFactorGraphTree();
|
||||||
return sum.empty() ? tree : sum.apply(tree, add);
|
return sum.empty() ? tree : sum.apply(tree, add);
|
||||||
|
|
@ -106,7 +106,7 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
|
||||||
auto wrap = [](const FactorAndConstant &factor_z) {
|
auto wrap = [](const FactorAndConstant &factor_z) {
|
||||||
GaussianFactorGraph result;
|
GaussianFactorGraph result;
|
||||||
result.push_back(factor_z.factor);
|
result.push_back(factor_z.factor);
|
||||||
return result;
|
return GaussianMixtureFactor::GraphAndConstant(result, factor_z.constant);
|
||||||
};
|
};
|
||||||
return {factors_, wrap};
|
return {factors_, wrap};
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -25,10 +25,10 @@
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
#include <gtsam/hybrid/HybridFactor.h>
|
#include <gtsam/hybrid/HybridFactor.h>
|
||||||
#include <gtsam/linear/GaussianFactor.h>
|
#include <gtsam/linear/GaussianFactor.h>
|
||||||
|
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||||
|
|
||||||
namespace gtsam {
|
namespace gtsam {
|
||||||
|
|
||||||
class GaussianFactorGraph;
|
|
||||||
class HybridValues;
|
class HybridValues;
|
||||||
class DiscreteValues;
|
class DiscreteValues;
|
||||||
class VectorValues;
|
class VectorValues;
|
||||||
|
|
@ -50,7 +50,6 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
||||||
using This = GaussianMixtureFactor;
|
using This = GaussianMixtureFactor;
|
||||||
using shared_ptr = boost::shared_ptr<This>;
|
using shared_ptr = boost::shared_ptr<This>;
|
||||||
|
|
||||||
using Sum = DecisionTree<Key, GaussianFactorGraph>;
|
|
||||||
using sharedFactor = boost::shared_ptr<GaussianFactor>;
|
using sharedFactor = boost::shared_ptr<GaussianFactor>;
|
||||||
|
|
||||||
/// Gaussian factor and log of normalizing constant.
|
/// Gaussian factor and log of normalizing constant.
|
||||||
|
|
@ -60,8 +59,9 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
||||||
|
|
||||||
// Return error with constant correction.
|
// Return error with constant correction.
|
||||||
double error(const VectorValues &values) const {
|
double error(const VectorValues &values) const {
|
||||||
// Note minus sign: constant is log of normalization constant for probabilities.
|
// Note: constant is log of normalization constant for probabilities.
|
||||||
// Errors is the negative log-likelihood, hence we subtract the constant here.
|
// Errors is the negative log-likelihood,
|
||||||
|
// hence we subtract the constant here.
|
||||||
return factor->error(values) - constant;
|
return factor->error(values) - constant;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -71,6 +71,22 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Gaussian factor graph and log of normalizing constant.
|
||||||
|
struct GraphAndConstant {
|
||||||
|
GaussianFactorGraph graph;
|
||||||
|
double constant;
|
||||||
|
|
||||||
|
GraphAndConstant(const GaussianFactorGraph &graph, double constant)
|
||||||
|
: graph(graph), constant(constant) {}
|
||||||
|
|
||||||
|
// Check pointer equality.
|
||||||
|
bool operator==(const GraphAndConstant &other) const {
|
||||||
|
return graph == other.graph && constant == other.constant;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
using Sum = DecisionTree<Key, GraphAndConstant>;
|
||||||
|
|
||||||
/// typedef for Decision Tree of Gaussian factors and log-constant.
|
/// typedef for Decision Tree of Gaussian factors and log-constant.
|
||||||
using Factors = DecisionTree<Key, FactorAndConstant>;
|
using Factors = DecisionTree<Key, FactorAndConstant>;
|
||||||
using Mixture = DecisionTree<Key, sharedFactor>;
|
using Mixture = DecisionTree<Key, sharedFactor>;
|
||||||
|
|
|
||||||
|
|
@ -61,18 +61,19 @@ template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
static GaussianMixtureFactor::Sum &addGaussian(
|
static GaussianMixtureFactor::Sum &addGaussian(
|
||||||
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
|
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
|
||||||
using Y = GaussianFactorGraph;
|
|
||||||
// If the decision tree is not initialized, then initialize it.
|
// If the decision tree is not initialized, then initialize it.
|
||||||
if (sum.empty()) {
|
if (sum.empty()) {
|
||||||
GaussianFactorGraph result;
|
GaussianFactorGraph result;
|
||||||
result.push_back(factor);
|
result.push_back(factor);
|
||||||
sum = GaussianMixtureFactor::Sum(result);
|
sum = GaussianMixtureFactor::Sum(
|
||||||
|
GaussianMixtureFactor::GraphAndConstant(result, 0.0));
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
auto add = [&factor](const Y &graph) {
|
auto add = [&factor](
|
||||||
auto result = graph;
|
const GaussianMixtureFactor::GraphAndConstant &graph_z) {
|
||||||
|
auto result = graph_z.graph;
|
||||||
result.push_back(factor);
|
result.push_back(factor);
|
||||||
return result;
|
return GaussianMixtureFactor::GraphAndConstant(result, graph_z.constant);
|
||||||
};
|
};
|
||||||
sum = sum.apply(add);
|
sum = sum.apply(add);
|
||||||
}
|
}
|
||||||
|
|
@ -190,19 +191,23 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
|
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
|
||||||
discreteSeparatorSet.end());
|
discreteSeparatorSet.end());
|
||||||
|
|
||||||
// sum out frontals, this is the factor 𝜏 on the separator
|
// Collect all the frontal factors to create Gaussian factor graphs
|
||||||
|
// indexed on the discrete keys.
|
||||||
GaussianMixtureFactor::Sum sum = sumFrontals(factors);
|
GaussianMixtureFactor::Sum sum = sumFrontals(factors);
|
||||||
|
|
||||||
// If a tree leaf contains nullptr,
|
// If a tree leaf contains nullptr,
|
||||||
// convert that leaf to an empty GaussianFactorGraph.
|
// convert that leaf to an empty GaussianFactorGraph.
|
||||||
// Needed since the DecisionTree will otherwise create
|
// Needed since the DecisionTree will otherwise create
|
||||||
// a GFG with a single (null) factor.
|
// a GFG with a single (null) factor.
|
||||||
auto emptyGaussian = [](const GaussianFactorGraph &gfg) {
|
auto emptyGaussian =
|
||||||
bool hasNull =
|
[](const GaussianMixtureFactor::GraphAndConstant &graph_z) {
|
||||||
std::any_of(gfg.begin(), gfg.end(),
|
bool hasNull = std::any_of(
|
||||||
|
graph_z.graph.begin(), graph_z.graph.end(),
|
||||||
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
|
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
|
||||||
|
|
||||||
return hasNull ? GaussianFactorGraph() : gfg;
|
return hasNull ? GaussianMixtureFactor::GraphAndConstant(
|
||||||
|
GaussianFactorGraph(), 0.0)
|
||||||
|
: graph_z;
|
||||||
};
|
};
|
||||||
sum = GaussianMixtureFactor::Sum(sum, emptyGaussian);
|
sum = GaussianMixtureFactor::Sum(sum, emptyGaussian);
|
||||||
|
|
||||||
|
|
@ -210,11 +215,12 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
GaussianMixtureFactor::FactorAndConstant>;
|
GaussianMixtureFactor::FactorAndConstant>;
|
||||||
|
|
||||||
KeyVector keysOfEliminated; // Not the ordering
|
KeyVector keysOfEliminated; // Not the ordering
|
||||||
KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)?
|
KeyVector keysOfSeparator;
|
||||||
|
|
||||||
// This is the elimination method on the leaf nodes
|
// This is the elimination method on the leaf nodes
|
||||||
auto eliminate = [&](const GaussianFactorGraph &graph) -> EliminationPair {
|
auto eliminate = [&](const GaussianMixtureFactor::GraphAndConstant &graph_z)
|
||||||
if (graph.empty()) {
|
-> EliminationPair {
|
||||||
|
if (graph_z.graph.empty()) {
|
||||||
return {nullptr, {nullptr, 0.0}};
|
return {nullptr, {nullptr, 0.0}};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -224,7 +230,8 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
|
|
||||||
std::pair<boost::shared_ptr<GaussianConditional>,
|
std::pair<boost::shared_ptr<GaussianConditional>,
|
||||||
boost::shared_ptr<GaussianFactor>>
|
boost::shared_ptr<GaussianFactor>>
|
||||||
conditional_factor = EliminatePreferCholesky(graph, frontalKeys);
|
conditional_factor =
|
||||||
|
EliminatePreferCholesky(graph_z.graph, frontalKeys);
|
||||||
|
|
||||||
// Initialize the keysOfEliminated to be the keys of the
|
// Initialize the keysOfEliminated to be the keys of the
|
||||||
// eliminated GaussianConditional
|
// eliminated GaussianConditional
|
||||||
|
|
@ -235,8 +242,10 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
gttoc_(hybrid_eliminate);
|
gttoc_(hybrid_eliminate);
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// TODO(Varun) The normalizing constant has to be computed correctly
|
GaussianConditional::shared_ptr conditional = conditional_factor.first;
|
||||||
return {conditional_factor.first, {conditional_factor.second, 0.0}};
|
// Get the log of the log normalization constant inverse.
|
||||||
|
double logZ = -conditional->logNormalizationConstant() + graph_z.constant;
|
||||||
|
return {conditional, {conditional_factor.second, logZ}};
|
||||||
};
|
};
|
||||||
|
|
||||||
// Perform elimination!
|
// Perform elimination!
|
||||||
|
|
@ -270,6 +279,13 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
DecisionTree<Key, double> fdt(separatorFactors, factorProb);
|
DecisionTree<Key, double> fdt(separatorFactors, factorProb);
|
||||||
|
// Normalize the values of decision tree to be valid probabilities
|
||||||
|
double sum = 0.0;
|
||||||
|
auto visitor = [&](double y) { sum += y; };
|
||||||
|
fdt.visit(visitor);
|
||||||
|
// fdt = DecisionTree<Key, double>(fdt,
|
||||||
|
// [sum](const double &x) { return x / sum;
|
||||||
|
// });
|
||||||
|
|
||||||
auto discreteFactor =
|
auto discreteFactor =
|
||||||
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
|
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue