name change of Sum to GaussianFactorGraphTree and SumFrontals to assembleGraphTree
parent
06aed5397a
commit
f8d75abfeb
|
|
@ -51,28 +51,28 @@ GaussianMixture::GaussianMixture(
|
||||||
Conditionals(discreteParents, conditionalsList)) {}
|
Conditionals(discreteParents, conditionalsList)) {}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianMixture::Sum GaussianMixture::add(
|
GaussianFactorGraphTree GaussianMixture::add(
|
||||||
const GaussianMixture::Sum &sum) const {
|
const GaussianFactorGraphTree &sum) const {
|
||||||
using Y = GaussianMixtureFactor::GraphAndConstant;
|
using Y = GraphAndConstant;
|
||||||
auto add = [](const Y &graph1, const Y &graph2) {
|
auto add = [](const Y &graph1, const Y &graph2) {
|
||||||
auto result = graph1.graph;
|
auto result = graph1.graph;
|
||||||
result.push_back(graph2.graph);
|
result.push_back(graph2.graph);
|
||||||
return Y(result, graph1.constant + graph2.constant);
|
return Y(result, graph1.constant + graph2.constant);
|
||||||
};
|
};
|
||||||
const Sum tree = asGaussianFactorGraphTree();
|
const auto tree = asGaussianFactorGraphTree();
|
||||||
return sum.empty() ? tree : sum.apply(tree, add);
|
return sum.empty() ? tree : sum.apply(tree, add);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianMixture::Sum GaussianMixture::asGaussianFactorGraphTree() const {
|
GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
|
||||||
auto lambda = [](const GaussianConditional::shared_ptr &conditional) {
|
auto lambda = [](const GaussianConditional::shared_ptr &conditional) {
|
||||||
GaussianFactorGraph result;
|
GaussianFactorGraph result;
|
||||||
result.push_back(conditional);
|
result.push_back(conditional);
|
||||||
if (conditional) {
|
if (conditional) {
|
||||||
return GaussianMixtureFactor::GraphAndConstant(
|
return GraphAndConstant(
|
||||||
result, conditional->logNormalizationConstant());
|
result, conditional->logNormalizationConstant());
|
||||||
} else {
|
} else {
|
||||||
return GaussianMixtureFactor::GraphAndConstant(result, 0.0);
|
return GraphAndConstant(result, 0.0);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
return {conditionals_, lambda};
|
return {conditionals_, lambda};
|
||||||
|
|
|
||||||
|
|
@ -59,9 +59,6 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
using BaseFactor = HybridFactor;
|
using BaseFactor = HybridFactor;
|
||||||
using BaseConditional = Conditional<HybridFactor, GaussianMixture>;
|
using BaseConditional = Conditional<HybridFactor, GaussianMixture>;
|
||||||
|
|
||||||
/// Alias for DecisionTree of GaussianFactorGraphs
|
|
||||||
using Sum = DecisionTree<Key, GaussianMixtureFactor::GraphAndConstant>;
|
|
||||||
|
|
||||||
/// typedef for Decision Tree of Gaussian Conditionals
|
/// typedef for Decision Tree of Gaussian Conditionals
|
||||||
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
|
using Conditionals = DecisionTree<Key, GaussianConditional::shared_ptr>;
|
||||||
|
|
||||||
|
|
@ -71,7 +68,7 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
/**
|
/**
|
||||||
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
|
* @brief Convert a DecisionTree of factors into a DT of Gaussian FGs.
|
||||||
*/
|
*/
|
||||||
Sum asGaussianFactorGraphTree() const;
|
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Helper function to get the pruner functor.
|
* @brief Helper function to get the pruner functor.
|
||||||
|
|
@ -172,6 +169,16 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
*/
|
*/
|
||||||
double error(const HybridValues &values) const override;
|
double error(const HybridValues &values) const override;
|
||||||
|
|
||||||
|
// /// Calculate probability density for given values `x`.
|
||||||
|
// double evaluate(const HybridValues &values) const;
|
||||||
|
|
||||||
|
// /// Evaluate probability density, sugar.
|
||||||
|
// double operator()(const HybridValues &values) const { return
|
||||||
|
// evaluate(values); }
|
||||||
|
|
||||||
|
// /// Calculate log-density for given values `x`.
|
||||||
|
// double logDensity(const HybridValues &values) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Prune the decision tree of Gaussian factors as per the discrete
|
* @brief Prune the decision tree of Gaussian factors as per the discrete
|
||||||
* `decisionTree`.
|
* `decisionTree`.
|
||||||
|
|
@ -186,9 +193,9 @@ class GTSAM_EXPORT GaussianMixture
|
||||||
* maintaining the decision tree structure.
|
* maintaining the decision tree structure.
|
||||||
*
|
*
|
||||||
* @param sum Decision Tree of Gaussian Factor Graphs
|
* @param sum Decision Tree of Gaussian Factor Graphs
|
||||||
* @return Sum
|
* @return GaussianFactorGraphTree
|
||||||
*/
|
*/
|
||||||
Sum add(const Sum &sum) const;
|
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -100,25 +100,25 @@ double GaussianMixtureFactor::constant(const DiscreteValues &assignment) const {
|
||||||
// }
|
// }
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianMixtureFactor::Sum GaussianMixtureFactor::add(
|
GaussianFactorGraphTree GaussianMixtureFactor::add(
|
||||||
const GaussianMixtureFactor::Sum &sum) const {
|
const GaussianFactorGraphTree &sum) const {
|
||||||
using Y = GaussianMixtureFactor::GraphAndConstant;
|
using Y = GraphAndConstant;
|
||||||
auto add = [](const Y &graph1, const Y &graph2) {
|
auto add = [](const Y &graph1, const Y &graph2) {
|
||||||
auto result = graph1.graph;
|
auto result = graph1.graph;
|
||||||
result.push_back(graph2.graph);
|
result.push_back(graph2.graph);
|
||||||
return Y(result, graph1.constant + graph2.constant);
|
return Y(result, graph1.constant + graph2.constant);
|
||||||
};
|
};
|
||||||
const Sum tree = asGaussianFactorGraphTree();
|
const auto tree = asGaussianFactorGraphTree();
|
||||||
return sum.empty() ? tree : sum.apply(tree, add);
|
return sum.empty() ? tree : sum.apply(tree, add);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianMixtureFactor::Sum GaussianMixtureFactor::asGaussianFactorGraphTree()
|
GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
|
||||||
const {
|
const {
|
||||||
auto wrap = [](const FactorAndConstant &factor_z) {
|
auto wrap = [](const FactorAndConstant &factor_z) {
|
||||||
GaussianFactorGraph result;
|
GaussianFactorGraph result;
|
||||||
result.push_back(factor_z.factor);
|
result.push_back(factor_z.factor);
|
||||||
return GaussianMixtureFactor::GraphAndConstant(result, factor_z.constant);
|
return GraphAndConstant(result, factor_z.constant);
|
||||||
};
|
};
|
||||||
return {factors_, wrap};
|
return {factors_, wrap};
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -72,35 +72,6 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Gaussian factor graph and log of normalizing constant.
|
|
||||||
struct GraphAndConstant {
|
|
||||||
GaussianFactorGraph graph;
|
|
||||||
double constant;
|
|
||||||
|
|
||||||
GraphAndConstant(const GaussianFactorGraph &graph, double constant)
|
|
||||||
: graph(graph), constant(constant) {}
|
|
||||||
|
|
||||||
// Check pointer equality.
|
|
||||||
bool operator==(const GraphAndConstant &other) const {
|
|
||||||
return graph == other.graph && constant == other.constant;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Implement GTSAM-style print:
|
|
||||||
void print(const std::string &s = "Graph: ",
|
|
||||||
const KeyFormatter &formatter = DefaultKeyFormatter) const {
|
|
||||||
graph.print(s, formatter);
|
|
||||||
std::cout << "Constant: " << constant << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Implement GTSAM-style equals:
|
|
||||||
bool equals(const GraphAndConstant &other, double tol = 1e-9) const {
|
|
||||||
return graph.equals(other.graph, tol) &&
|
|
||||||
fabs(constant - other.constant) < tol;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
using Sum = DecisionTree<Key, GraphAndConstant>;
|
|
||||||
|
|
||||||
/// typedef for Decision Tree of Gaussian factors and log-constant.
|
/// typedef for Decision Tree of Gaussian factors and log-constant.
|
||||||
using Factors = DecisionTree<Key, FactorAndConstant>;
|
using Factors = DecisionTree<Key, FactorAndConstant>;
|
||||||
using Mixture = DecisionTree<Key, sharedFactor>;
|
using Mixture = DecisionTree<Key, sharedFactor>;
|
||||||
|
|
@ -113,9 +84,9 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
||||||
* @brief Helper function to return factors and functional to create a
|
* @brief Helper function to return factors and functional to create a
|
||||||
* DecisionTree of Gaussian Factor Graphs.
|
* DecisionTree of Gaussian Factor Graphs.
|
||||||
*
|
*
|
||||||
* @return Sum (DecisionTree<Key, GaussianFactorGraph>)
|
* @return GaussianFactorGraphTree
|
||||||
*/
|
*/
|
||||||
Sum asGaussianFactorGraphTree() const;
|
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
/// @name Constructors
|
/// @name Constructors
|
||||||
|
|
@ -184,7 +155,7 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
||||||
* variables.
|
* variables.
|
||||||
* @return Sum
|
* @return Sum
|
||||||
*/
|
*/
|
||||||
Sum add(const Sum &sum) const;
|
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @brief Compute error of the GaussianMixtureFactor as a tree.
|
* @brief Compute error of the GaussianMixtureFactor as a tree.
|
||||||
|
|
@ -202,7 +173,8 @@ class GTSAM_EXPORT GaussianMixtureFactor : public HybridFactor {
|
||||||
double error(const HybridValues &values) const override;
|
double error(const HybridValues &values) const override;
|
||||||
|
|
||||||
/// Add MixtureFactor to a Sum, syntactic sugar.
|
/// Add MixtureFactor to a Sum, syntactic sugar.
|
||||||
friend Sum &operator+=(Sum &sum, const GaussianMixtureFactor &factor) {
|
friend GaussianFactorGraphTree &operator+=(
|
||||||
|
GaussianFactorGraphTree &sum, const GaussianMixtureFactor &factor) {
|
||||||
sum = factor.add(sum);
|
sum = factor.add(sum);
|
||||||
return sum;
|
return sum;
|
||||||
}
|
}
|
||||||
|
|
@ -215,7 +187,6 @@ struct traits<GaussianMixtureFactor> : public Testable<GaussianMixtureFactor> {
|
||||||
};
|
};
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
struct traits<GaussianMixtureFactor::GraphAndConstant>
|
struct traits<GraphAndConstant> : public Testable<GraphAndConstant> {};
|
||||||
: public Testable<GaussianMixtureFactor::GraphAndConstant> {};
|
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -279,18 +279,31 @@ double HybridBayesNet::evaluate(const HybridValues &values) const {
|
||||||
const VectorValues &continuousValues = values.continuous();
|
const VectorValues &continuousValues = values.continuous();
|
||||||
|
|
||||||
double logDensity = 0.0, probability = 1.0;
|
double logDensity = 0.0, probability = 1.0;
|
||||||
|
bool debug = false;
|
||||||
|
|
||||||
// Iterate over each conditional.
|
// Iterate over each conditional.
|
||||||
for (auto &&conditional : *this) {
|
for (auto &&conditional : *this) {
|
||||||
|
// TODO: should be delegated to derived classes.
|
||||||
if (auto gm = conditional->asMixture()) {
|
if (auto gm = conditional->asMixture()) {
|
||||||
const auto component = (*gm)(discreteValues);
|
const auto component = (*gm)(discreteValues);
|
||||||
logDensity += component->logDensity(continuousValues);
|
logDensity += component->logDensity(continuousValues);
|
||||||
|
if (debug) {
|
||||||
|
GTSAM_PRINT(continuousValues);
|
||||||
|
std::cout << "component->logDensity(continuousValues) = "
|
||||||
|
<< component->logDensity(continuousValues) << std::endl;
|
||||||
|
}
|
||||||
} else if (auto gc = conditional->asGaussian()) {
|
} else if (auto gc = conditional->asGaussian()) {
|
||||||
// If continuous only, evaluate the probability and multiply.
|
// If continuous only, evaluate the probability and multiply.
|
||||||
logDensity += gc->logDensity(continuousValues);
|
logDensity += gc->logDensity(continuousValues);
|
||||||
|
if (debug)
|
||||||
|
std::cout << "gc->logDensity(continuousValues) = "
|
||||||
|
<< gc->logDensity(continuousValues) << std::endl;
|
||||||
} else if (auto dc = conditional->asDiscrete()) {
|
} else if (auto dc = conditional->asDiscrete()) {
|
||||||
// Conditional is discrete-only, so return its probability.
|
// Conditional is discrete-only, so return its probability.
|
||||||
probability *= dc->operator()(discreteValues);
|
probability *= dc->operator()(discreteValues);
|
||||||
|
if (debug)
|
||||||
|
std::cout << "dc->operator()(discreteValues) = "
|
||||||
|
<< dc->operator()(discreteValues) << std::endl;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,8 @@
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
#include <gtsam/inference/Factor.h>
|
#include <gtsam/inference/Factor.h>
|
||||||
#include <gtsam/nonlinear/Values.h>
|
#include <gtsam/nonlinear/Values.h>
|
||||||
|
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||||
|
#include <gtsam/discrete/DecisionTree.h>
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
@ -28,6 +30,36 @@ namespace gtsam {
|
||||||
|
|
||||||
class HybridValues;
|
class HybridValues;
|
||||||
|
|
||||||
|
/// Gaussian factor graph and log of normalizing constant.
|
||||||
|
struct GraphAndConstant {
|
||||||
|
GaussianFactorGraph graph;
|
||||||
|
double constant;
|
||||||
|
|
||||||
|
GraphAndConstant(const GaussianFactorGraph &graph, double constant)
|
||||||
|
: graph(graph), constant(constant) {}
|
||||||
|
|
||||||
|
// Check pointer equality.
|
||||||
|
bool operator==(const GraphAndConstant &other) const {
|
||||||
|
return graph == other.graph && constant == other.constant;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement GTSAM-style print:
|
||||||
|
void print(const std::string &s = "Graph: ",
|
||||||
|
const KeyFormatter &formatter = DefaultKeyFormatter) const {
|
||||||
|
graph.print(s, formatter);
|
||||||
|
std::cout << "Constant: " << constant << std::endl;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implement GTSAM-style equals:
|
||||||
|
bool equals(const GraphAndConstant &other, double tol = 1e-9) const {
|
||||||
|
return graph.equals(other.graph, tol) &&
|
||||||
|
fabs(constant - other.constant) < tol;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Alias for DecisionTree of GaussianFactorGraphs
|
||||||
|
using GaussianFactorGraphTree = DecisionTree<Key, GraphAndConstant>;
|
||||||
|
|
||||||
KeyVector CollectKeys(const KeyVector &continuousKeys,
|
KeyVector CollectKeys(const KeyVector &continuousKeys,
|
||||||
const DiscreteKeys &discreteKeys);
|
const DiscreteKeys &discreteKeys);
|
||||||
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
|
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
|
||||||
|
|
|
||||||
|
|
@ -59,51 +59,44 @@ namespace gtsam {
|
||||||
template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
|
template class EliminateableFactorGraph<HybridGaussianFactorGraph>;
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
static GaussianMixtureFactor::Sum addGaussian(
|
static GaussianFactorGraphTree addGaussian(
|
||||||
const GaussianMixtureFactor::Sum &sum,
|
const GaussianFactorGraphTree &sum,
|
||||||
const GaussianFactor::shared_ptr &factor) {
|
const GaussianFactor::shared_ptr &factor) {
|
||||||
// If the decision tree is not initialized, then initialize it.
|
// If the decision tree is not initialized, then initialize it.
|
||||||
if (sum.empty()) {
|
if (sum.empty()) {
|
||||||
GaussianFactorGraph result;
|
GaussianFactorGraph result;
|
||||||
result.push_back(factor);
|
result.push_back(factor);
|
||||||
return GaussianMixtureFactor::Sum(
|
return GaussianFactorGraphTree(GraphAndConstant(result, 0.0));
|
||||||
GaussianMixtureFactor::GraphAndConstant(result, 0.0));
|
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
auto add = [&factor](
|
auto add = [&factor](const GraphAndConstant &graph_z) {
|
||||||
const GaussianMixtureFactor::GraphAndConstant &graph_z) {
|
|
||||||
auto result = graph_z.graph;
|
auto result = graph_z.graph;
|
||||||
result.push_back(factor);
|
result.push_back(factor);
|
||||||
return GaussianMixtureFactor::GraphAndConstant(result, graph_z.constant);
|
return GraphAndConstant(result, graph_z.constant);
|
||||||
};
|
};
|
||||||
return sum.apply(add);
|
return sum.apply(add);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
// TODO(dellaert): At the time I though "Sum" was a good name, but coming back
|
// TODO(dellaert): We need to document why deferredFactors need to be
|
||||||
// to it after a while I think "SumFrontals" is better.it's a terrible name.
|
// added last, which I would undo if possible. Implementation-wise, it's
|
||||||
// Also, the implementation is inconsistent. I think we should just have a
|
// probably more efficient to first collect the discrete keys, and then loop
|
||||||
// virtual method in HybridFactor that adds to the "Sum" object, like
|
// over all assignments to populate a vector.
|
||||||
// addGaussian. Finally, we need to document why deferredFactors need to be
|
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
||||||
// added last, which I would undo if possible.
|
gttic(assembleGraphTree);
|
||||||
// Implementation-wise, it's probably more efficient to first collect the
|
|
||||||
// discrete keys, and then loop over all assignments to populate a vector.
|
|
||||||
GaussianMixtureFactor::Sum HybridGaussianFactorGraph::SumFrontals() const {
|
|
||||||
// sum out frontals, this is the factor on the separator
|
|
||||||
gttic(sum);
|
|
||||||
|
|
||||||
GaussianMixtureFactor::Sum sum;
|
GaussianFactorGraphTree result;
|
||||||
std::vector<GaussianFactor::shared_ptr> deferredFactors;
|
std::vector<GaussianFactor::shared_ptr> deferredFactors;
|
||||||
|
|
||||||
for (auto &f : factors_) {
|
for (auto &f : factors_) {
|
||||||
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
||||||
if (f->isHybrid()) {
|
if (f->isHybrid()) {
|
||||||
if (auto gm = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
if (auto gm = boost::dynamic_pointer_cast<GaussianMixtureFactor>(f)) {
|
||||||
sum = gm->add(sum);
|
result = gm->add(result);
|
||||||
}
|
}
|
||||||
if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) {
|
if (auto gm = boost::dynamic_pointer_cast<HybridConditional>(f)) {
|
||||||
sum = gm->asMixture()->add(sum);
|
result = gm->asMixture()->add(result);
|
||||||
}
|
}
|
||||||
|
|
||||||
} else if (f->isContinuous()) {
|
} else if (f->isContinuous()) {
|
||||||
|
|
@ -134,12 +127,12 @@ GaussianMixtureFactor::Sum HybridGaussianFactorGraph::SumFrontals() const {
|
||||||
}
|
}
|
||||||
|
|
||||||
for (auto &f : deferredFactors) {
|
for (auto &f : deferredFactors) {
|
||||||
sum = addGaussian(sum, f);
|
result = addGaussian(result, f);
|
||||||
}
|
}
|
||||||
|
|
||||||
gttoc(sum);
|
gttoc(assembleGraphTree);
|
||||||
|
|
||||||
return sum;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
|
|
@ -192,18 +185,14 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
// If any GaussianFactorGraph in the decision tree contains a nullptr, convert
|
// If any GaussianFactorGraph in the decision tree contains a nullptr, convert
|
||||||
// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
|
// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
|
||||||
// otherwise create a GFG with a single (null) factor.
|
// otherwise create a GFG with a single (null) factor.
|
||||||
GaussianMixtureFactor::Sum removeEmpty(const GaussianMixtureFactor::Sum &sum) {
|
GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
|
||||||
auto emptyGaussian = [](const GaussianMixtureFactor::GraphAndConstant
|
auto emptyGaussian = [](const GraphAndConstant &graph_z) {
|
||||||
&graph_z) {
|
|
||||||
bool hasNull =
|
bool hasNull =
|
||||||
std::any_of(graph_z.graph.begin(), graph_z.graph.end(),
|
std::any_of(graph_z.graph.begin(), graph_z.graph.end(),
|
||||||
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
|
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
|
||||||
return hasNull
|
return hasNull ? GraphAndConstant{GaussianFactorGraph(), 0.0} : graph_z;
|
||||||
? GaussianMixtureFactor::GraphAndConstant{GaussianFactorGraph(),
|
|
||||||
0.0}
|
|
||||||
: graph_z;
|
|
||||||
};
|
};
|
||||||
return GaussianMixtureFactor::Sum(sum, emptyGaussian);
|
return GaussianFactorGraphTree(sum, emptyGaussian);
|
||||||
}
|
}
|
||||||
/* ************************************************************************ */
|
/* ************************************************************************ */
|
||||||
static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
|
static std::pair<HybridConditional::shared_ptr, HybridFactor::shared_ptr>
|
||||||
|
|
@ -218,17 +207,16 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
|
|
||||||
// Collect all the factors to create a set of Gaussian factor graphs in a
|
// Collect all the factors to create a set of Gaussian factor graphs in a
|
||||||
// decision tree indexed by all discrete keys involved.
|
// decision tree indexed by all discrete keys involved.
|
||||||
GaussianMixtureFactor::Sum sum = factors.SumFrontals();
|
GaussianFactorGraphTree sum = factors.assembleGraphTree();
|
||||||
|
|
||||||
// TODO(dellaert): does SumFrontals not guarantee we do not need this?
|
// TODO(dellaert): does assembleGraphTree not guarantee we do not need this?
|
||||||
sum = removeEmpty(sum);
|
sum = removeEmpty(sum);
|
||||||
|
|
||||||
using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>,
|
using EliminationPair = std::pair<boost::shared_ptr<GaussianConditional>,
|
||||||
GaussianMixtureFactor::FactorAndConstant>;
|
GaussianMixtureFactor::FactorAndConstant>;
|
||||||
|
|
||||||
// This is the elimination method on the leaf nodes
|
// This is the elimination method on the leaf nodes
|
||||||
auto eliminate = [&](const GaussianMixtureFactor::GraphAndConstant &graph_z)
|
auto eliminate = [&](const GraphAndConstant &graph_z) -> EliminationPair {
|
||||||
-> EliminationPair {
|
|
||||||
if (graph_z.graph.empty()) {
|
if (graph_z.graph.empty()) {
|
||||||
return {nullptr, {nullptr, 0.0}};
|
return {nullptr, {nullptr, 0.0}};
|
||||||
}
|
}
|
||||||
|
|
@ -247,7 +235,15 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// Get the log of the log normalization constant inverse.
|
// Get the log of the log normalization constant inverse.
|
||||||
double logZ = graph_z.constant - conditional->logNormalizationConstant();
|
double logZ = -conditional->logNormalizationConstant();
|
||||||
|
|
||||||
|
// IF this is the last continuous variable to eliminated, we need to
|
||||||
|
// calculate the error here: the value of all factors at the mean, see
|
||||||
|
// ml_map_rao.pdf.
|
||||||
|
if (continuousSeparator.empty()) {
|
||||||
|
const auto posterior_mean = conditional->solve(VectorValues());
|
||||||
|
logZ += graph_z.graph.error(posterior_mean);
|
||||||
|
}
|
||||||
return {conditional, {newFactor, logZ}};
|
return {conditional, {newFactor, logZ}};
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -245,7 +245,7 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
||||||
* one for A and one for B. The leaves of the tree will be the Gaussian
|
* one for A and one for B. The leaves of the tree will be the Gaussian
|
||||||
* factors that have only continuous keys.
|
* factors that have only continuous keys.
|
||||||
*/
|
*/
|
||||||
GaussianMixtureFactor::Sum SumFrontals() const;
|
GaussianFactorGraphTree assembleGraphTree() const;
|
||||||
|
|
||||||
/// @}
|
/// @}
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -80,7 +80,7 @@ TEST(GaussianMixtureFactor, Sum) {
|
||||||
|
|
||||||
// Create sum of two mixture factors: it will be a decision tree now on both
|
// Create sum of two mixture factors: it will be a decision tree now on both
|
||||||
// discrete variables m1 and m2:
|
// discrete variables m1 and m2:
|
||||||
GaussianMixtureFactor::Sum sum;
|
GaussianFactorGraphTree sum;
|
||||||
sum += mixtureFactorA;
|
sum += mixtureFactorA;
|
||||||
sum += mixtureFactorB;
|
sum += mixtureFactorB;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -615,15 +615,16 @@ TEST(HybridGaussianFactorGraph, ErrorAndProbPrimeTree) {
|
||||||
}
|
}
|
||||||
|
|
||||||
/* ****************************************************************************/
|
/* ****************************************************************************/
|
||||||
// Check that SumFrontals assembles Gaussian factor graphs for each assignment.
|
// Check that assembleGraphTree assembles Gaussian factor graphs for each
|
||||||
TEST(HybridGaussianFactorGraph, SumFrontals) {
|
// assignment.
|
||||||
|
TEST(HybridGaussianFactorGraph, assembleGraphTree) {
|
||||||
const int num_measurements = 1;
|
const int num_measurements = 1;
|
||||||
const bool deterministic = true;
|
const bool deterministic = true;
|
||||||
auto fg =
|
auto fg =
|
||||||
tiny::createHybridGaussianFactorGraph(num_measurements, deterministic);
|
tiny::createHybridGaussianFactorGraph(num_measurements, deterministic);
|
||||||
EXPECT_LONGS_EQUAL(3, fg.size());
|
EXPECT_LONGS_EQUAL(3, fg.size());
|
||||||
|
|
||||||
auto sum = fg.SumFrontals();
|
auto sum = fg.assembleGraphTree();
|
||||||
|
|
||||||
// Get mixture factor:
|
// Get mixture factor:
|
||||||
auto mixture = boost::dynamic_pointer_cast<GaussianMixtureFactor>(fg.at(0));
|
auto mixture = boost::dynamic_pointer_cast<GaussianMixtureFactor>(fg.at(0));
|
||||||
|
|
@ -638,7 +639,7 @@ TEST(HybridGaussianFactorGraph, SumFrontals) {
|
||||||
|
|
||||||
// Expected decision tree with two factor graphs:
|
// Expected decision tree with two factor graphs:
|
||||||
// f(x0;mode=0)P(x0) and f(x0;mode=1)P(x0)
|
// f(x0;mode=0)P(x0) and f(x0;mode=1)P(x0)
|
||||||
GaussianMixture::Sum expectedSum{
|
GaussianFactorGraphTree expectedSum{
|
||||||
M(0),
|
M(0),
|
||||||
{GaussianFactorGraph(std::vector<GF>{mixture->factor(d0), prior}),
|
{GaussianFactorGraph(std::vector<GF>{mixture->factor(d0), prior}),
|
||||||
mixture->constant(d0)},
|
mixture->constant(d0)},
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue