GraphAndConstant struct for hybrid elimination

release/4.3a0
Varun Agrawal 2023-01-01 06:36:48 -05:00
parent 2e6f477569
commit a23322350c
5 changed files with 73 additions and 36 deletions

View File

@ -53,11 +53,11 @@ GaussianMixture::GaussianMixture(
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixture::Sum GaussianMixture::add( GaussianMixture::Sum GaussianMixture::add(
const GaussianMixture::Sum &sum) const { const GaussianMixture::Sum &sum) const {
using Y = GaussianFactorGraph; using Y = GaussianMixtureFactor::GraphAndConstant;
auto add = [](const Y &graph1, const Y &graph2) { auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1; auto result = graph1.graph;
result.push_back(graph2); result.push_back(graph2.graph);
return result; return Y(result, graph1.constant + graph2.constant);
}; };
const Sum tree = asGaussianFactorGraphTree(); const Sum tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add); return sum.empty() ? tree : sum.apply(tree, add);
@ -65,10 +65,15 @@ GaussianMixture::Sum GaussianMixture::add(
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixture::Sum GaussianMixture::asGaussianFactorGraphTree() const { GaussianMixture::Sum GaussianMixture::asGaussianFactorGraphTree() const {
auto lambda = [](const GaussianFactor::shared_ptr &factor) { auto lambda = [](const GaussianConditional::shared_ptr &conditional) {
GaussianFactorGraph result; GaussianFactorGraph result;
result.push_back(factor); result.push_back(conditional);
return result; if (conditional) {
return GaussianMixtureFactor::GraphAndConstant(
result, conditional->logNormalizationConstant());
} else {
return GaussianMixtureFactor::GraphAndConstant(result, 0.0);
}
}; };
return {conditionals_, lambda}; return {conditionals_, lambda};
} }

View File

@ -23,13 +23,13 @@
#include <gtsam/discrete/DecisionTree.h> #include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/GaussianMixtureFactor.h>
#include <gtsam/hybrid/HybridFactor.h> #include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/inference/Conditional.h> #include <gtsam/inference/Conditional.h>
#include <gtsam/linear/GaussianConditional.h> #include <gtsam/linear/GaussianConditional.h>
namespace gtsam { namespace gtsam {
class GaussianMixtureFactor;
class HybridValues; class HybridValues;
/** /**
@ -60,7 +60,7 @@ class GTSAM_EXPORT GaussianMixture
using BaseConditional = Conditional<HybridFactor, GaussianMixture>; using BaseConditional = Conditional<HybridFactor, GaussianMixture>;
/// Alias for DecisionTree of GaussianFactorGraphs /// Alias for DecisionTree of GaussianFactorGraphs
using Sum = DecisionTree<Key, GaussianFactorGraph>; using Sum = DecisionTree<Key, GaussianMixtureFactor::GraphAndConstant>;
/// typedef for Decision Tree of Gaussian Conditionals /// typedef for Decision Tree of Gaussian Conditionals
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>; using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;

View File

@ -90,11 +90,11 @@ const GaussianMixtureFactor::Mixture GaussianMixtureFactor::factors() const {
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixtureFactor::Sum GaussianMixtureFactor::add( GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
const GaussianMixtureFactor::Sum &sum) const { const GaussianMixtureFactor::Sum &sum) const {
using Y = GaussianFactorGraph; using Y = GaussianMixtureFactor::GraphAndConstant;
auto add = [](const Y &graph1, const Y &graph2) { auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1; auto result = graph1.graph;
result.push_back(graph2); result.push_back(graph2.graph);
return result; return Y(result, graph1.constant + graph2.constant);
}; };
const Sum tree = asGaussianFactorGraphTree(); const Sum tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add); return sum.empty() ? tree : sum.apply(tree, add);
@ -106,7 +106,7 @@ GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
auto wrap = [](const FactorAndConstant &factor_z) { auto wrap = [](const FactorAndConstant &factor_z) {
GaussianFactorGraph result; GaussianFactorGraph result;
result.push_back(factor_z.factor); result.push_back(factor_z.factor);
return result; return GaussianMixtureFactor::GraphAndConstant(result, factor_z.constant);
}; };
return {factors_, wrap}; return {factors_, wrap};
} }

View File

@ -25,10 +25,10 @@
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridFactor.h> #include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/linear/GaussianFactor.h> #include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h>
namespace gtsam { namespace gtsam {
class GaussianFactorGraph;
class HybridValues; class HybridValues;
class DiscreteValues; class DiscreteValues;
class VectorValues; class VectorValues;
@ -50,7 +50,6 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
using This = GaussianMixtureFactor; using This = GaussianMixtureFactor;
using shared_ptr = boost::shared_ptr<This>; using shared_ptr = boost::shared_ptr<This>;
using Sum = DecisionTree<Key, GaussianFactorGraph>;
using sharedFactor = boost::shared_ptr<GaussianFactor>; using sharedFactor = boost::shared_ptr<GaussianFactor>;
/// Gaussian factor and log of normalizing constant. /// Gaussian factor and log of normalizing constant.
@ -60,8 +59,9 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
// Return error with constant correction. // Return error with constant correction.
double error(const VectorValues &values) const { double error(const VectorValues &values) const {
// Note minus sign: constant is log of normalization constant for probabilities. // Note: constant is log of normalization constant for probabilities.
// Errors is the negative log-likelihood, hence we subtract the constant here. // Errors is the negative log-likelihood,
// hence we subtract the constant here.
return factor->error(values) - constant; return factor->error(values) - constant;
} }
@ -71,6 +71,22 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
} }
}; };
/// Gaussian factor graph and log of normalizing constant.
struct GraphAndConstant {
GaussianFactorGraph graph;
double constant;
GraphAndConstant(const GaussianFactorGraph &graph, double constant)
: graph(graph), constant(constant) {}
// Check pointer equality.
bool operator==(const GraphAndConstant &other) const {
return graph == other.graph && constant == other.constant;
}
};
using Sum = DecisionTree<Key, GraphAndConstant>;
/// typedef for Decision Tree of Gaussian factors and log-constant. /// typedef for Decision Tree of Gaussian factors and log-constant.
using Factors = DecisionTree<Key, FactorAndConstant>; using Factors = DecisionTree<Key, FactorAndConstant>;
using Mixture = DecisionTree<Key, sharedFactor>; using Mixture = DecisionTree<Key, sharedFactor>;

View File

@ -61,18 +61,19 @@ template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
/* ************************************************************************ */ /* ************************************************************************ */
static GaussianMixtureFactor::Sum &addGaussian( static GaussianMixtureFactor::Sum &addGaussian(
GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) { GaussianMixtureFactor::Sum &sum, const GaussianFactor::shared_ptr &factor) {
using Y = GaussianFactorGraph;
// If the decision tree is not initialized, then initialize it. // If the decision tree is not initialized, then initialize it.
if (sum.empty()) { if (sum.empty()) {
GaussianFactorGraph result; GaussianFactorGraph result;
result.push_back(factor); result.push_back(factor);
sum = GaussianMixtureFactor::Sum(result); sum = GaussianMixtureFactor::Sum(
GaussianMixtureFactor::GraphAndConstant(result, 0.0));
} else { } else {
auto add = [&factor](const Y &graph) { auto add = [&factor](
auto result = graph; const GaussianMixtureFactor::GraphAndConstant &graph_z) {
auto result = graph_z.graph;
result.push_back(factor); result.push_back(factor);
return result; return GaussianMixtureFactor::GraphAndConstant(result, graph_z.constant);
}; };
sum = sum.apply(add); sum = sum.apply(add);
} }
@ -190,19 +191,23 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(), DiscreteKeys discreteSeparator(discreteSeparatorSet.begin(),
discreteSeparatorSet.end()); discreteSeparatorSet.end());
// sum out frontals, this is the factor 𝜏 on the separator // Collect all the frontal factors to create Gaussian factor graphs
// indexed on the discrete keys.
GaussianMixtureFactor::Sum sum = sumFrontals(factors); GaussianMixtureFactor::Sum sum = sumFrontals(factors);
// If a tree leaf contains nullptr, // If a tree leaf contains nullptr,
// convert that leaf to an empty GaussianFactorGraph. // convert that leaf to an empty GaussianFactorGraph.
// Needed since the DecisionTree will otherwise create // Needed since the DecisionTree will otherwise create
// a GFG with a single (null) factor. // a GFG with a single (null) factor.
auto emptyGaussian = [](const GaussianFactorGraph &gfg) { auto emptyGaussian =
bool hasNull = [](const GaussianMixtureFactor::GraphAndConstant &graph_z) {
std::any_of(gfg.begin(), gfg.end(), bool hasNull = std::any_of(
graph_z.graph.begin(), graph_z.graph.end(),
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; }); [](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
return hasNull ? GaussianFactorGraph() : gfg; return hasNull ? GaussianMixtureFactor::GraphAndConstant(
GaussianFactorGraph(), 0.0)
: graph_z;
}; };
sum = GaussianMixtureFactor::Sum(sum, emptyGaussian); sum = GaussianMixtureFactor::Sum(sum, emptyGaussian);
@ -210,11 +215,12 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
GaussianMixtureFactor::FactorAndConstant>; GaussianMixtureFactor::FactorAndConstant>;
KeyVector keysOfEliminated; // Not the ordering KeyVector keysOfEliminated; // Not the ordering
KeyVector keysOfSeparator; // TODO(frank): Is this just (keys - ordering)? KeyVector keysOfSeparator;
// This is the elimination method on the leaf nodes // This is the elimination method on the leaf nodes
auto eliminate = [&](const GaussianFactorGraph &graph) -> EliminationPair { auto eliminate = [&](const GaussianMixtureFactor::GraphAndConstant &graph_z)
if (graph.empty()) { -> EliminationPair {
if (graph_z.graph.empty()) {
return {nullptr, {nullptr, 0.0}}; return {nullptr, {nullptr, 0.0}};
} }
@ -224,7 +230,8 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
std::pair<boost::shared_ptr<GaussianConditional>, std::pair<boost::shared_ptr<GaussianConditional>,
boost::shared_ptr<GaussianFactor>> boost::shared_ptr<GaussianFactor>>
conditional_factor = EliminatePreferCholesky(graph, frontalKeys); conditional_factor =
EliminatePreferCholesky(graph_z.graph, frontalKeys);
// Initialize the keysOfEliminated to be the keys of the // Initialize the keysOfEliminated to be the keys of the
// eliminated GaussianConditional // eliminated GaussianConditional
@ -235,8 +242,10 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
gttoc_(hybrid_eliminate); gttoc_(hybrid_eliminate);
#endif #endif
// TODO(Varun) The normalizing constant has to be computed correctly GaussianConditional::shared_ptr conditional = conditional_factor.first;
return {conditional_factor.first, {conditional_factor.second, 0.0}}; // Get the log of the log normalization constant inverse.
double logZ = -conditional->logNormalizationConstant() + graph_z.constant;
return {conditional, {conditional_factor.second, logZ}};
}; };
// Perform elimination! // Perform elimination!
@ -270,6 +279,13 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
} }
}; };
DecisionTree<Key, double> fdt(separatorFactors, factorProb); DecisionTree<Key, double> fdt(separatorFactors, factorProb);
// Normalize the values of decision tree to be valid probabilities
double sum = 0.0;
auto visitor = [&](double y) { sum += y; };
fdt.visit(visitor);
// fdt = DecisionTree<Key, double>(fdt,
// [sum](const double &x) { return x / sum;
// });
auto discreteFactor = auto discreteFactor =
boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt); boost::make_shared<DecisionTreeFactor>(discreteSeparator, fdt);