name change of Sum to GaussianFactorGraphTree and SumFrontals to assembleGraphTree

release/4.3a0
Frank Dellaert 2023-01-02 13:11:22 -05:00
parent 06aed5397a
commit f8d75abfeb
10 changed files with 118 additions and 98 deletions

View File

@ -51,28 +51,28 @@ GaussianMixture::GaussianMixture(
Conditionals(discreteParents, conditionalsList)) {} Conditionals(discreteParents, conditionalsList)) {}
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixture::Sum GaussianMixture::add( GaussianFactorGraphTree GaussianMixture::add(
const GaussianMixture::Sum &sum) const { const GaussianFactorGraphTree &sum) const {
using Y = GaussianMixtureFactor::GraphAndConstant; using Y = GraphAndConstant;
auto add = [](const Y &graph1, const Y &graph2) { auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1.graph; auto result = graph1.graph;
result.push_back(graph2.graph); result.push_back(graph2.graph);
return Y(result, graph1.constant + graph2.constant); return Y(result, graph1.constant + graph2.constant);
}; };
const Sum tree = asGaussianFactorGraphTree(); const auto tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add); return sum.empty() ? tree : sum.apply(tree, add);
} }
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixture::Sum GaussianMixture::asGaussianFactorGraphTree() const { GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
auto lambda = [](const GaussianConditional::shared_ptr &conditional) { auto lambda = [](const GaussianConditional::shared_ptr &conditional) {
GaussianFactorGraph result; GaussianFactorGraph result;
result.push_back(conditional); result.push_back(conditional);
if (conditional) { if (conditional) {
return GaussianMixtureFactor::GraphAndConstant( return GraphAndConstant(
result, conditional->logNormalizationConstant()); result, conditional->logNormalizationConstant());
} else { } else {
return GaussianMixtureFactor::GraphAndConstant(result, 0.0); return GraphAndConstant(result, 0.0);
} }
}; };
return {conditionals_, lambda}; return {conditionals_, lambda};

View File

@ -59,9 +59,6 @@ class GTSAM_EXPORT GaussianMixture
using BaseFactor = HybridFactor; using BaseFactor = HybridFactor;
using BaseConditional = Conditional<HybridFactor, GaussianMixture>; using BaseConditional = Conditional<HybridFactor, GaussianMixture>;
/// Alias for DecisionTree of GaussianFactorGraphs
using Sum = DecisionTree<Key, GaussianMixtureFactor::GraphAndConstant>;
/// typedef for Decision Tree of Gaussian Conditionals /// typedef for Decision Tree of Gaussian Conditionals
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>; using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
@ -71,7 +68,7 @@ class GTSAM_EXPORT GaussianMixture
/** /**
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs. * @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
*/ */
Sum asGaussianFactorGraphTree() const; GaussianFactorGraphTree asGaussianFactorGraphTree() const;
/** /**
* @brief Helper function to get the pruner functor. * @brief Helper function to get the pruner functor.
@ -172,6 +169,16 @@ class GTSAM_EXPORT GaussianMixture
*/ */
double error(const HybridValues &values) const override; double error(const HybridValues &values) const override;
// /// Calculate probability density for given values `x`.
// double evaluate(const HybridValues &values) const;
// /// Evaluate probability density, sugar.
// double operator()(const HybridValues &values) const { return
// evaluate(values); }
// /// Calculate log-density for given values `x`.
// double logDensity(const HybridValues &values) const;
/** /**
* @brief Prune the decision tree of Gaussian factors as per the discrete * @brief Prune the decision tree of Gaussian factors as per the discrete
* `decisionTree`. * `decisionTree`.
@ -186,9 +193,9 @@ class GTSAM_EXPORT GaussianMixture
* maintaining the decision tree structure. * maintaining the decision tree structure.
* *
* @param sum Decision Tree of Gaussian Factor Graphs * @param sum Decision Tree of Gaussian Factor Graphs
* @return Sum * @return GaussianFactorGraphTree
*/ */
Sum add(const Sum &sum) const; GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
/// @} /// @}
}; };

View File

@ -100,25 +100,25 @@ double GaussianMixtureFactor::constant(const DiscreteValues &assignment) const {
// } // }
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixtureFactor::Sum GaussianMixtureFactor::add( GaussianFactorGraphTree GaussianMixtureFactor::add(
const GaussianMixtureFactor::Sum &sum) const { const GaussianFactorGraphTree &sum) const {
using Y = GaussianMixtureFactor::GraphAndConstant; using Y = GraphAndConstant;
auto add = [](const Y &graph1, const Y &graph2) { auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1.graph; auto result = graph1.graph;
result.push_back(graph2.graph); result.push_back(graph2.graph);
return Y(result, graph1.constant + graph2.constant); return Y(result, graph1.constant + graph2.constant);
}; };
const Sum tree = asGaussianFactorGraphTree(); const auto tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add); return sum.empty() ? tree : sum.apply(tree, add);
} }
/* *******************************************************************************/ /* *******************************************************************************/
GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree() GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
const { const {
auto wrap = [](const FactorAndConstant &factor_z) { auto wrap = [](const FactorAndConstant &factor_z) {
GaussianFactorGraph result; GaussianFactorGraph result;
result.push_back(factor_z.factor); result.push_back(factor_z.factor);
return GaussianMixtureFactor::GraphAndConstant(result, factor_z.constant); return GraphAndConstant(result, factor_z.constant);
}; };
return {factors_, wrap}; return {factors_, wrap};
} }

View File

@ -72,35 +72,6 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
} }
}; };
/// Gaussian factor graph and log of normalizing constant.
struct GraphAndConstant {
GaussianFactorGraph graph;
double constant;
GraphAndConstant(const GaussianFactorGraph &graph, double constant)
: graph(graph), constant(constant) {}
// Check pointer equality.
bool operator==(const GraphAndConstant &other) const {
return graph == other.graph && constant == other.constant;
}
// Implement GTSAM-style print:
void print(const std::string &s = "Graph: ",
const KeyFormatter &formatter = DefaultKeyFormatter) const {
graph.print(s, formatter);
std::cout << "Constant: " << constant << std::endl;
}
// Implement GTSAM-style equals:
bool equals(const GraphAndConstant &other, double tol = 1e-9) const {
return graph.equals(other.graph, tol) &&
fabs(constant - other.constant) < tol;
}
};
using Sum = DecisionTree<Key, GraphAndConstant>;
/// typedef for Decision Tree of Gaussian factors and log-constant. /// typedef for Decision Tree of Gaussian factors and log-constant.
using Factors = DecisionTree<Key, FactorAndConstant>; using Factors = DecisionTree<Key, FactorAndConstant>;
using Mixture = DecisionTree<Key, sharedFactor>; using Mixture = DecisionTree<Key, sharedFactor>;
@ -113,9 +84,9 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @brief Helper function to return factors and functional to create a * @brief Helper function to return factors and functional to create a
* DecisionTree of Gaussian Factor Graphs. * DecisionTree of Gaussian Factor Graphs.
* *
* @return Sum (DecisionTree<Key, GaussianFactorGraph>) * @return GaussianFactorGraphTree
*/ */
Sum asGaussianFactorGraphTree() const; GaussianFactorGraphTree asGaussianFactorGraphTree() const;
public: public:
/// @name Constructors /// @name Constructors
@ -184,7 +155,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* variables. * variables.
* @return Sum * @return Sum
*/ */
Sum add(const Sum &sum) const; GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
/** /**
* @brief Compute error of the GaussianMixtureFactor as a tree. * @brief Compute error of the GaussianMixtureFactor as a tree.
@ -202,7 +173,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
double error(const HybridValues &values) const override; double error(const HybridValues &values) const override;
/// Add MixtureFactor to a Sum, syntactic sugar. /// Add MixtureFactor to a Sum, syntactic sugar.
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) { friend GaussianFactorGraphTree &operator+=(
GaussianFactorGraphTree &sum, const GaussianMixtureFactor &factor) {
sum = factor.add(sum); sum = factor.add(sum);
return sum; return sum;
} }
@ -215,7 +187,6 @@ struct traits<GaussianMixtureFactor> : public Testable<GaussianMixtureFactor> {
}; };
template <> template <>
struct traits<GaussianMixtureFactor::GraphAndConstant> struct traits<GraphAndConstant> : public Testable<GraphAndConstant> {};
: public Testable<GaussianMixtureFactor::GraphAndConstant> {};
} // namespace gtsam } // namespace gtsam

View File

@ -279,18 +279,31 @@ double HybridBayesNet::evaluate(const HybridValues &values) const {
const VectorValues &continuousValues = values.continuous(); const VectorValues &continuousValues = values.continuous();
double logDensity = 0.0, probability = 1.0; double logDensity = 0.0, probability = 1.0;
bool debug = false;
// Iterate over each conditional. // Iterate over each conditional.
for (auto &&conditional : *this) { for (auto &&conditional : *this) {
// TODO: should be delegated to derived classes.
if (auto gm = conditional->asMixture()) { if (auto gm = conditional->asMixture()) {
const auto component = (*gm)(discreteValues); const auto component = (*gm)(discreteValues);
logDensity += component->logDensity(continuousValues); logDensity += component->logDensity(continuousValues);
if (debug) {
GTSAM_PRINT(continuousValues);
std::cout << "component->logDensity(continuousValues) = "
<< component->logDensity(continuousValues) << std::endl;
}
} else if (auto gc = conditional->asGaussian()) { } else if (auto gc = conditional->asGaussian()) {
// If continuous only, evaluate the probability and multiply. // If continuous only, evaluate the probability and multiply.
logDensity += gc->logDensity(continuousValues); logDensity += gc->logDensity(continuousValues);
if (debug)
std::cout << "gc->logDensity(continuousValues) = "
<< gc->logDensity(continuousValues) << std::endl;
} else if (auto dc = conditional->asDiscrete()) { } else if (auto dc = conditional->asDiscrete()) {
// Conditional is discrete-only, so return its probability. // Conditional is discrete-only, so return its probability.
probability *= dc->operator()(discreteValues); probability *= dc->operator()(discreteValues);
if (debug)
std::cout << "dc->operator()(discreteValues) = "
<< dc->operator()(discreteValues) << std::endl;
} }
} }

View File

@ -21,6 +21,8 @@
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/inference/Factor.h> #include <gtsam/inference/Factor.h>
#include <gtsam/nonlinear/Values.h> #include <gtsam/nonlinear/Values.h>
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/discrete/DecisionTree.h>
#include <cstddef> #include <cstddef>
#include <string> #include <string>
@ -28,6 +30,36 @@ namespace gtsam {
class HybridValues; class HybridValues;
/// Gaussian factor graph and log of normalizing constant.
struct GraphAndConstant {
GaussianFactorGraph graph;
double constant;
GraphAndConstant(const GaussianFactorGraph &graph, double constant)
: graph(graph), constant(constant) {}
// Check pointer equality.
bool operator==(const GraphAndConstant &other) const {
return graph == other.graph && constant == other.constant;
}
// Implement GTSAM-style print:
void print(const std::string &s = "Graph: ",
const KeyFormatter &formatter = DefaultKeyFormatter) const {
graph.print(s, formatter);
std::cout << "Constant: " << constant << std::endl;
}
// Implement GTSAM-style equals:
bool equals(const GraphAndConstant &other, double tol = 1e-9) const {
return graph.equals(other.graph, tol) &&
fabs(constant - other.constant) < tol;
}
};
/// Alias for DecisionTree of GaussianFactorGraphs
using GaussianFactorGraphTree = DecisionTree<Key, GraphAndConstant>;
KeyVector CollectKeys(const KeyVector &continuousKeys, KeyVector CollectKeys(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys); const DiscreteKeys &discreteKeys);
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2); KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);

View File

@ -59,51 +59,44 @@ namespace gtsam {
template class EliminateableFactorGraph<HybridGaussianFactorGraph>; template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
/* ************************************************************************ */ /* ************************************************************************ */
static GaussianMixtureFactor::Sum addGaussian( static GaussianFactorGraphTree addGaussian(
const GaussianMixtureFactor::Sum &sum, const GaussianFactorGraphTree &sum,
const GaussianFactor::shared_ptr &factor) { const GaussianFactor::shared_ptr &factor) {
// If the decision tree is not initialized, then initialize it. // If the decision tree is not initialized, then initialize it.
if (sum.empty()) { if (sum.empty()) {
GaussianFactorGraph result; GaussianFactorGraph result;
result.push_back(factor); result.push_back(factor);
return GaussianMixtureFactor::Sum( return GaussianFactorGraphTree(GraphAndConstant(result, 0.0));
GaussianMixtureFactor::GraphAndConstant(result, 0.0));
} else { } else {
auto add = [&factor]( auto add = [&factor](const GraphAndConstant &graph_z) {
const GaussianMixtureFactor::GraphAndConstant &graph_z) {
auto result = graph_z.graph; auto result = graph_z.graph;
result.push_back(factor); result.push_back(factor);
return GaussianMixtureFactor::GraphAndConstant(result, graph_z.constant); return GraphAndConstant(result, graph_z.constant);
}; };
return sum.apply(add); return sum.apply(add);
} }
} }
/* ************************************************************************ */ /* ************************************************************************ */
// TODO(dellaert): At the time I though "Sum" was a good name, but coming back // TODO(dellaert): We need to document why deferredFactors need to be
// to it after a while I think "SumFrontals" is better.it's a terrible name. // added last, which I would undo if possible. Implementation-wise, it's
// Also, the implementation is inconsistent. I think we should just have a // probably more efficient to first collect the discrete keys, and then loop
// virtual method in HybridFactor that adds to the "Sum" object, like // over all assignments to populate a vector.
// addGaussian. Finally, we need to document why deferredFactors need to be GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
// added last, which I would undo if possible. gttic(assembleGraphTree);
// Implementation-wise, it's probably more efficient to first collect the
// discrete keys, and then loop over all assignments to populate a vector.
GaussianMixtureFactor::Sum HybridGaussianFactorGraph::SumFrontals() const {
// sum out frontals, this is the factor on the separator
gttic(sum);
GaussianMixtureFactor::Sum sum; GaussianFactorGraphTree result;
std::vector<GaussianFactor::shared_ptr> deferredFactors; std::vector<GaussianFactor::shared_ptr> deferredFactors;
for (auto &f : factors_) { for (auto &f : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor. // TODO(dellaert): just use a virtual method defined in HybridFactor.
if (f->isHybrid()) { if (f->isHybrid()) {
if (auto gm = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) { if (auto gm = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
sum = gm->add(sum); result = gm->add(result);
} }
if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) { if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) {
sum = gm->asMixture()->add(sum); result = gm->asMixture()->add(result);
} }
} else if (f->isContinuous()) { } else if (f->isContinuous()) {
@ -134,12 +127,12 @@ GaussianMixtureFactor::Sum HybridGaussianFactorGraph::SumFrontals() const {
} }
for (auto &f : deferredFactors) { for (auto &f : deferredFactors) {
sum = addGaussian(sum, f); result = addGaussian(result, f);
} }
gttoc(sum); gttoc(assembleGraphTree);
return sum; return result;
} }
/* ************************************************************************ */ /* ************************************************************************ */
@ -192,18 +185,14 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
// If any GaussianFactorGraph in the decision tree contains a nullptr, convert // If any GaussianFactorGraph in the decision tree contains a nullptr, convert
// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will // that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
// otherwise create a GFG with a single (null) factor. // otherwise create a GFG with a single (null) factor.
GaussianMixtureFactor::Sum removeEmpty(const GaussianMixtureFactor::Sum &sum) { GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
auto emptyGaussian = [](const GaussianMixtureFactor::GraphAndConstant auto emptyGaussian = [](const GraphAndConstant &graph_z) {
&graph_z) {
bool hasNull = bool hasNull =
std::any_of(graph_z.graph.begin(), graph_z.graph.end(), std::any_of(graph_z.graph.begin(), graph_z.graph.end(),
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; }); [](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
return hasNull return hasNull ? GraphAndConstant{GaussianFactorGraph(), 0.0} : graph_z;
? GaussianMixtureFactor::GraphAndConstant{GaussianFactorGraph(),
0.0}
: graph_z;
}; };
return GaussianMixtureFactor::Sum(sum, emptyGaussian); return GaussianFactorGraphTree(sum, emptyGaussian);
} }
/* ************************************************************************ */ /* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr> static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
@ -218,17 +207,16 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// Collect all the factors to create a set of Gaussian factor graphs in a // Collect all the factors to create a set of Gaussian factor graphs in a
// decision tree indexed by all discrete keys involved. // decision tree indexed by all discrete keys involved.
GaussianMixtureFactor::Sum sum = factors.SumFrontals(); GaussianFactorGraphTree sum = factors.assembleGraphTree();
// TODO(dellaert): does SumFrontals not guarantee we do not need this? // TODO(dellaert): does assembleGraphTree not guarantee we do not need this?
sum = removeEmpty(sum); sum = removeEmpty(sum);
using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>, using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>,
GaussianMixtureFactor::FactorAndConstant>; GaussianMixtureFactor::FactorAndConstant>;
// This is the elimination method on the leaf nodes // This is the elimination method on the leaf nodes
auto eliminate = [&](const GaussianMixtureFactor::GraphAndConstant &graph_z) auto eliminate = [&](const GraphAndConstant &graph_z) -> EliminationPair {
-> EliminationPair {
if (graph_z.graph.empty()) { if (graph_z.graph.empty()) {
return {nullptr, {nullptr, 0.0}}; return {nullptr, {nullptr, 0.0}};
} }
@ -247,7 +235,15 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
#endif #endif
// Get the log of the log normalization constant inverse. // Get the log of the log normalization constant inverse.
double logZ = graph_z.constant - conditional->logNormalizationConstant(); double logZ = -conditional->logNormalizationConstant();
// IF this is the last continuous variable to eliminated, we need to
// calculate the error here: the value of all factors at the mean, see
// ml_map_rao.pdf.
if (continuousSeparator.empty()) {
const auto posterior_mean = conditional->solve(VectorValues());
logZ += graph_z.graph.error(posterior_mean);
}
return {conditional, {newFactor, logZ}}; return {conditional, {newFactor, logZ}};
}; };

View File

@ -245,7 +245,7 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
* one for A and one for B. The leaves of the tree will be the Gaussian * one for A and one for B. The leaves of the tree will be the Gaussian
* factors that have only continuous keys. * factors that have only continuous keys.
*/ */
GaussianMixtureFactor::Sum SumFrontals() const; GaussianFactorGraphTree assembleGraphTree() const;
/// @} /// @}
}; };

View File

@ -80,7 +80,7 @@ TEST(GaussianMixtureFactor, Sum) {
// Create sum of two mixture factors: it will be a decision tree now on both // Create sum of two mixture factors: it will be a decision tree now on both
// discrete variables m1 and m2: // discrete variables m1 and m2:
GaussianMixtureFactor::Sum sum; GaussianFactorGraphTree sum;
sum += mixtureFactorA; sum += mixtureFactorA;
sum += mixtureFactorB; sum += mixtureFactorB;

View File

@ -615,15 +615,16 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
} }
/* ****************************************************************************/ /* ****************************************************************************/
// Check that SumFrontals assembles Gaussian factor graphs for each assignment. // Check that assembleGraphTree assembles Gaussian factor graphs for each
TEST(HybridGaussianFactorGraph, SumFrontals) { // assignment.
TEST(HybridGaussianFactorGraph, assembleGraphTree) {
const int num_measurements = 1; const int num_measurements = 1;
const bool deterministic = true; const bool deterministic = true;
auto fg = auto fg =
tiny::createHybridGaussianFactorGraph(num_measurements, deterministic); tiny::createHybridGaussianFactorGraph(num_measurements, deterministic);
EXPECT_LONGS_EQUAL(3, fg.size()); EXPECT_LONGS_EQUAL(3, fg.size());
auto sum = fg.SumFrontals(); auto sum = fg.assembleGraphTree();
// Get mixture factor: // Get mixture factor:
auto mixture = boost::dynamic_pointer_cast<GaussianMixtureFactor>(fg.at(0)); auto mixture = boost::dynamic_pointer_cast<GaussianMixtureFactor>(fg.at(0));
@ -638,7 +639,7 @@ TEST(HybridGaussianFactorGraph, SumFrontals) {
// Expected decision tree with two factor graphs: // Expected decision tree with two factor graphs:
// f(x0;mode=0)P(x0) and f(x0;mode=1)P(x0) // f(x0;mode=0)P(x0) and f(x0;mode=1)P(x0)
GaussianMixture::Sum expectedSum{ GaussianFactorGraphTree expectedSum{
M(0), M(0),
{GaussianFactorGraph(std::vector<GF>{mixture->factor(d0), prior}), {GaussianFactorGraph(std::vector<GF>{mixture->factor(d0), prior}),
mixture->constant(d0)}, mixture->constant(d0)},