name change of Sum to GaussianFactorGraphTree and SumFrontals to assembleGraphTree

release/4.3a0
Frank Dellaert 2023-01-02 13:11:22 -05:00
parent 06aed5397a
commit f8d75abfeb
10 changed files with 118 additions and 98 deletions

View File

@ -51,28 +51,28 @@ GaussianMixture::GaussianMixture(
Conditionals(discreteParents, conditionalsList)) {}
/* *******************************************************************************/
GaussianMixture::Sum GaussianMixture::add(
const GaussianMixture::Sum &sum) const {
using Y = GaussianMixtureFactor::GraphAndConstant;
GaussianFactorGraphTree GaussianMixture::add(
const GaussianFactorGraphTree &sum) const {
using Y = GraphAndConstant;
auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1.graph;
result.push_back(graph2.graph);
return Y(result, graph1.constant + graph2.constant);
};
const Sum tree = asGaussianFactorGraphTree();
const auto tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add);
}
/* *******************************************************************************/
GaussianMixture::Sum GaussianMixture::asGaussianFactorGraphTree() const {
GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
auto lambda = [](const GaussianConditional::shared_ptr &conditional) {
GaussianFactorGraph result;
result.push_back(conditional);
if (conditional) {
return GaussianMixtureFactor::GraphAndConstant(
return GraphAndConstant(
result, conditional->logNormalizationConstant());
} else {
return GaussianMixtureFactor::GraphAndConstant(result, 0.0);
return GraphAndConstant(result, 0.0);
}
};
return {conditionals_, lambda};
@ -108,7 +108,7 @@ bool GaussianMixture::equals(const HybridFactor &lf, double tol) const {
// This will return false if either conditionals_ is empty or e->conditionals_
// is empty, but not if both are empty or both are not empty:
if (conditionals_.empty() ^ e->conditionals_.empty()) return false;
std::cout << "checking" << std::endl;
std::cout << "checking" << std::endl;
// Check the base and the factors:
return BaseFactor::equals(*e, tol) &&
conditionals_.equals(e->conditionals_,

View File

@ -59,9 +59,6 @@ class GTSAM_EXPORT GaussianMixture
using BaseFactor = HybridFactor;
using BaseConditional = Conditional<HybridFactor, GaussianMixture>;
/// Alias for DecisionTree of GaussianFactorGraphs
using Sum = DecisionTree<Key, GaussianMixtureFactor::GraphAndConstant>;
/// typedef for Decision Tree of Gaussian Conditionals
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
@ -71,7 +68,7 @@ class GTSAM_EXPORT GaussianMixture
/**
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
*/
Sum asGaussianFactorGraphTree() const;
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
/**
* @brief Helper function to get the pruner functor.
@ -172,6 +169,16 @@ class GTSAM_EXPORT GaussianMixture
*/
double error(const HybridValues &values) const override;
// /// Calculate probability density for given values `x`.
// double evaluate(const HybridValues &values) const;
// /// Evaluate probability density, sugar.
// double operator()(const HybridValues &values) const { return
// evaluate(values); }
// /// Calculate log-density for given values `x`.
// double logDensity(const HybridValues &values) const;
/**
* @brief Prune the decision tree of Gaussian factors as per the discrete
* `decisionTree`.
@ -186,9 +193,9 @@ class GTSAM_EXPORT GaussianMixture
* maintaining the decision tree structure.
*
* @param sum Decision Tree of Gaussian Factor Graphs
* @return Sum
* @return GaussianFactorGraphTree
*/
Sum add(const Sum &sum) const;
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
/// @}
};

View File

@ -100,25 +100,25 @@ double GaussianMixtureFactor::constant(const DiscreteValues &assignment) const {
// }
/* *******************************************************************************/
GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
const GaussianMixtureFactor::Sum &sum) const {
using Y = GaussianMixtureFactor::GraphAndConstant;
GaussianFactorGraphTree GaussianMixtureFactor::add(
const GaussianFactorGraphTree &sum) const {
using Y = GraphAndConstant;
auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1.graph;
result.push_back(graph2.graph);
return Y(result, graph1.constant + graph2.constant);
};
const Sum tree = asGaussianFactorGraphTree();
const auto tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add);
}
/* *******************************************************************************/
GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
const {
auto wrap = [](const FactorAndConstant &factor_z) {
GaussianFactorGraph result;
result.push_back(factor_z.factor);
return GaussianMixtureFactor::GraphAndConstant(result, factor_z.constant);
return GraphAndConstant(result, factor_z.constant);
};
return {factors_, wrap};
}

View File

@ -72,35 +72,6 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
}
};
/// Gaussian factor graph and log of normalizing constant.
struct GraphAndConstant {
GaussianFactorGraph graph;
double constant;
GraphAndConstant(const GaussianFactorGraph &graph, double constant)
: graph(graph), constant(constant) {}
// Check pointer equality.
bool operator==(const GraphAndConstant &other) const {
return graph == other.graph && constant == other.constant;
}
// Implement GTSAM-style print:
void print(const std::string &s = "Graph: ",
const KeyFormatter &formatter = DefaultKeyFormatter) const {
graph.print(s, formatter);
std::cout << "Constant: " << constant << std::endl;
}
// Implement GTSAM-style equals:
bool equals(const GraphAndConstant &other, double tol = 1e-9) const {
return graph.equals(other.graph, tol) &&
fabs(constant - other.constant) < tol;
}
};
using Sum = DecisionTree<Key, GraphAndConstant>;
/// typedef for Decision Tree of Gaussian factors and log-constant.
using Factors = DecisionTree<Key, FactorAndConstant>;
using Mixture = DecisionTree<Key, sharedFactor>;
@ -113,9 +84,9 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* @brief Helper function to return factors and functional to create a
* DecisionTree of Gaussian Factor Graphs.
*
* @return Sum (DecisionTree<Key, GaussianFactorGraph>)
* @return GaussianFactorGraphTree
*/
Sum asGaussianFactorGraphTree() const;
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
public:
/// @name Constructors
@ -184,7 +155,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
* variables.
* @return Sum
*/
Sum add(const Sum &sum) const;
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
/**
* @brief Compute error of the GaussianMixtureFactor as a tree.
@ -202,7 +173,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
double error(const HybridValues &values) const override;
/// Add MixtureFactor to a Sum, syntactic sugar.
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
friend GaussianFactorGraphTree &operator+=(
GaussianFactorGraphTree &sum, const GaussianMixtureFactor &factor) {
sum = factor.add(sum);
return sum;
}
@ -215,7 +187,6 @@ struct traits<GaussianMixtureFactor> : public Testable<GaussianMixtureFactor> {
};
template <>
struct traits<GaussianMixtureFactor::GraphAndConstant>
: public Testable<GaussianMixtureFactor::GraphAndConstant> {};
struct traits<GraphAndConstant> : public Testable<GraphAndConstant> {};
} // namespace gtsam

View File

@ -279,18 +279,31 @@ double HybridBayesNet::evaluate(const HybridValues &values) const {
const VectorValues &continuousValues = values.continuous();
double logDensity = 0.0, probability = 1.0;
bool debug = false;
// Iterate over each conditional.
for (auto &&conditional : *this) {
// TODO: should be delegated to derived classes.
if (auto gm = conditional->asMixture()) {
const auto component = (*gm)(discreteValues);
logDensity += component->logDensity(continuousValues);
if (debug) {
GTSAM_PRINT(continuousValues);
std::cout << "component->logDensity(continuousValues) = "
<< component->logDensity(continuousValues) << std::endl;
}
} else if (auto gc = conditional->asGaussian()) {
// If continuous only, evaluate the probability and multiply.
logDensity += gc->logDensity(continuousValues);
if (debug)
std::cout << "gc->logDensity(continuousValues) = "
<< gc->logDensity(continuousValues) << std::endl;
} else if (auto dc = conditional->asDiscrete()) {
// Conditional is discrete-only, so return its probability.
probability *= dc->operator()(discreteValues);
if (debug)
std::cout << "dc->operator()(discreteValues) = "
<< dc->operator()(discreteValues) << std::endl;
}
}

View File

@ -21,6 +21,8 @@
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/inference/Factor.h>
#include <gtsam/nonlinear/Values.h>
#include <gtsam/linear/GaussianFactorGraph.h>
#include <gtsam/discrete/DecisionTree.h>
#include <cstddef>
#include <string>
@ -28,6 +30,36 @@ namespace gtsam {
class HybridValues;
/// Gaussian factor graph and log of normalizing constant.
struct GraphAndConstant {
GaussianFactorGraph graph;
double constant;
GraphAndConstant(const GaussianFactorGraph &graph, double constant)
: graph(graph), constant(constant) {}
// Check pointer equality.
bool operator==(const GraphAndConstant &other) const {
return graph == other.graph && constant == other.constant;
}
// Implement GTSAM-style print:
void print(const std::string &s = "Graph: ",
const KeyFormatter &formatter = DefaultKeyFormatter) const {
graph.print(s, formatter);
std::cout << "Constant: " << constant << std::endl;
}
// Implement GTSAM-style equals:
bool equals(const GraphAndConstant &other, double tol = 1e-9) const {
return graph.equals(other.graph, tol) &&
fabs(constant - other.constant) < tol;
}
};
/// Alias for DecisionTree of GaussianFactorGraphs
using GaussianFactorGraphTree = DecisionTree<Key, GraphAndConstant>;
KeyVector CollectKeys(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys);
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);

View File

@ -59,51 +59,44 @@ namespace gtsam {
template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
/* ************************************************************************ */
static GaussianMixtureFactor::Sum addGaussian(
const GaussianMixtureFactor::Sum &sum,
static GaussianFactorGraphTree addGaussian(
const GaussianFactorGraphTree &sum,
const GaussianFactor::shared_ptr &factor) {
// If the decision tree is not initialized, then initialize it.
if (sum.empty()) {
GaussianFactorGraph result;
result.push_back(factor);
return GaussianMixtureFactor::Sum(
GaussianMixtureFactor::GraphAndConstant(result, 0.0));
return GaussianFactorGraphTree(GraphAndConstant(result, 0.0));
} else {
auto add = [&factor](
const GaussianMixtureFactor::GraphAndConstant &graph_z) {
auto add = [&factor](const GraphAndConstant &graph_z) {
auto result = graph_z.graph;
result.push_back(factor);
return GaussianMixtureFactor::GraphAndConstant(result, graph_z.constant);
return GraphAndConstant(result, graph_z.constant);
};
return sum.apply(add);
}
}
/* ************************************************************************ */
// TODO(dellaert): At the time I though "Sum" was a good name, but coming back
// to it after a while I think "SumFrontals" is better.it's a terrible name.
// Also, the implementation is inconsistent. I think we should just have a
// virtual method in HybridFactor that adds to the "Sum" object, like
// addGaussian. Finally, we need to document why deferredFactors need to be
// added last, which I would undo if possible.
// Implementation-wise, it's probably more efficient to first collect the
// discrete keys, and then loop over all assignments to populate a vector.
GaussianMixtureFactor::Sum HybridGaussianFactorGraph::SumFrontals() const {
// sum out frontals, this is the factor on the separator
gttic(sum);
// TODO(dellaert): We need to document why deferredFactors need to be
// added last, which I would undo if possible. Implementation-wise, it's
// probably more efficient to first collect the discrete keys, and then loop
// over all assignments to populate a vector.
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
gttic(assembleGraphTree);
GaussianMixtureFactor::Sum sum;
GaussianFactorGraphTree result;
std::vector<GaussianFactor::shared_ptr> deferredFactors;
for (auto &f : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
if (f->isHybrid()) {
if (auto gm = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
sum = gm->add(sum);
result = gm->add(result);
}
if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) {
sum = gm->asMixture()->add(sum);
result = gm->asMixture()->add(result);
}
} else if (f->isContinuous()) {
@ -134,12 +127,12 @@ GaussianMixtureFactor::Sum HybridGaussianFactorGraph::SumFrontals() const {
}
for (auto &f : deferredFactors) {
sum = addGaussian(sum, f);
result = addGaussian(result, f);
}
gttoc(sum);
gttoc(assembleGraphTree);
return sum;
return result;
}
/* ************************************************************************ */
@ -192,18 +185,14 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
// If any GaussianFactorGraph in the decision tree contains a nullptr, convert
// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
// otherwise create a GFG with a single (null) factor.
GaussianMixtureFactor::Sum removeEmpty(const GaussianMixtureFactor::Sum &sum) {
auto emptyGaussian = [](const GaussianMixtureFactor::GraphAndConstant
&graph_z) {
GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
auto emptyGaussian = [](const GraphAndConstant &graph_z) {
bool hasNull =
std::any_of(graph_z.graph.begin(), graph_z.graph.end(),
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
return hasNull
? GaussianMixtureFactor::GraphAndConstant{GaussianFactorGraph(),
0.0}
: graph_z;
return hasNull ? GraphAndConstant{GaussianFactorGraph(), 0.0} : graph_z;
};
return GaussianMixtureFactor::Sum(sum, emptyGaussian);
return GaussianFactorGraphTree(sum, emptyGaussian);
}
/* ************************************************************************ */
static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
@ -218,17 +207,16 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
// Collect all the factors to create a set of Gaussian factor graphs in a
// decision tree indexed by all discrete keys involved.
GaussianMixtureFactor::Sum sum = factors.SumFrontals();
GaussianFactorGraphTree sum = factors.assembleGraphTree();
// TODO(dellaert): does SumFrontals not guarantee we do not need this?
// TODO(dellaert): does assembleGraphTree not guarantee we do not need this?
sum = removeEmpty(sum);
using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>,
GaussianMixtureFactor::FactorAndConstant>;
// This is the elimination method on the leaf nodes
auto eliminate = [&](const GaussianMixtureFactor::GraphAndConstant &graph_z)
-> EliminationPair {
auto eliminate = [&](const GraphAndConstant &graph_z) -> EliminationPair {
if (graph_z.graph.empty()) {
return {nullptr, {nullptr, 0.0}};
}
@ -247,7 +235,15 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
#endif
// Get the log of the log normalization constant inverse.
double logZ = graph_z.constant - conditional->logNormalizationConstant();
double logZ = -conditional->logNormalizationConstant();
// IF this is the last continuous variable to eliminated, we need to
// calculate the error here: the value of all factors at the mean, see
// ml_map_rao.pdf.
if (continuousSeparator.empty()) {
const auto posterior_mean = conditional->solve(VectorValues());
logZ += graph_z.graph.error(posterior_mean);
}
return {conditional, {newFactor, logZ}};
};

View File

@ -245,7 +245,7 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
* one for A and one for B. The leaves of the tree will be the Gaussian
* factors that have only continuous keys.
*/
GaussianMixtureFactor::Sum SumFrontals() const;
GaussianFactorGraphTree assembleGraphTree() const;
/// @}
};

View File

@ -80,7 +80,7 @@ TEST(GaussianMixtureFactor, Sum) {
// Create sum of two mixture factors: it will be a decision tree now on both
// discrete variables m1 and m2:
GaussianMixtureFactor::Sum sum;
GaussianFactorGraphTree sum;
sum += mixtureFactorA;
sum += mixtureFactorB;

View File

@ -615,15 +615,16 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
}
/* ****************************************************************************/
// Check that SumFrontals assembles Gaussian factor graphs for each assignment.
TEST(HybridGaussianFactorGraph, SumFrontals) {
// Check that assembleGraphTree assembles Gaussian factor graphs for each
// assignment.
TEST(HybridGaussianFactorGraph, assembleGraphTree) {
const int num_measurements = 1;
const bool deterministic = true;
auto fg =
tiny::createHybridGaussianFactorGraph(num_measurements, deterministic);
EXPECT_LONGS_EQUAL(3, fg.size());
auto sum = fg.SumFrontals();
auto sum = fg.assembleGraphTree();
// Get mixture factor:
auto mixture = boost::dynamic_pointer_cast<GaussianMixtureFactor>(fg.at(0));
@ -638,7 +639,7 @@ TEST(HybridGaussianFactorGraph, SumFrontals) {
// Expected decision tree with two factor graphs:
// f(x0;mode=0)P(x0) and f(x0;mode=1)P(x0)
GaussianMixture::Sum expectedSum{
GaussianFactorGraphTree expectedSum{
M(0),
{GaussianFactorGraph(std::vector<GF>{mixture->factor(d0), prior}),
mixture->constant(d0)},