Eradicated GraphAndConstant
parent
8357fc7e02
commit
e31884c9a1
|
|
@ -72,11 +72,11 @@ GaussianMixture::GaussianMixture(
|
||||||
// GaussianMixtureFactor, no?
|
// GaussianMixtureFactor, no?
|
||||||
GaussianFactorGraphTree GaussianMixture::add(
|
GaussianFactorGraphTree GaussianMixture::add(
|
||||||
const GaussianFactorGraphTree &sum) const {
|
const GaussianFactorGraphTree &sum) const {
|
||||||
using Y = GraphAndConstant;
|
using Y = GaussianFactorGraph;
|
||||||
auto add = [](const Y &graph1, const Y &graph2) {
|
auto add = [](const Y &graph1, const Y &graph2) {
|
||||||
auto result = graph1.graph;
|
auto result = graph1;
|
||||||
result.push_back(graph2.graph);
|
result.push_back(graph2);
|
||||||
return Y(result, graph1.constant + graph2.constant);
|
return result;
|
||||||
};
|
};
|
||||||
const auto tree = asGaussianFactorGraphTree();
|
const auto tree = asGaussianFactorGraphTree();
|
||||||
return sum.empty() ? tree : sum.apply(tree, add);
|
return sum.empty() ? tree : sum.apply(tree, add);
|
||||||
|
|
@ -84,16 +84,10 @@ GaussianFactorGraphTree GaussianMixture::add(
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
|
GaussianFactorGraphTree GaussianMixture::asGaussianFactorGraphTree() const {
|
||||||
auto lambda = [](const GaussianConditional::shared_ptr &conditional) {
|
auto wrap = [](const GaussianConditional::shared_ptr &gc) {
|
||||||
GaussianFactorGraph result;
|
return GaussianFactorGraph{gc};
|
||||||
result.push_back(conditional);
|
|
||||||
if (conditional) {
|
|
||||||
return GraphAndConstant(result, conditional->logNormalizationConstant());
|
|
||||||
} else {
|
|
||||||
return GraphAndConstant(result, 0.0);
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
return {conditionals_, lambda};
|
return {conditionals_, wrap};
|
||||||
}
|
}
|
||||||
|
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
|
|
|
||||||
|
|
@ -84,11 +84,11 @@ GaussianMixtureFactor::sharedFactor GaussianMixtureFactor::operator()(
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianFactorGraphTree GaussianMixtureFactor::add(
|
GaussianFactorGraphTree GaussianMixtureFactor::add(
|
||||||
const GaussianFactorGraphTree &sum) const {
|
const GaussianFactorGraphTree &sum) const {
|
||||||
using Y = GraphAndConstant;
|
using Y = GaussianFactorGraph;
|
||||||
auto add = [](const Y &graph1, const Y &graph2) {
|
auto add = [](const Y &graph1, const Y &graph2) {
|
||||||
auto result = graph1.graph;
|
auto result = graph1;
|
||||||
result.push_back(graph2.graph);
|
result.push_back(graph2);
|
||||||
return Y(result, graph1.constant + graph2.constant);
|
return result;
|
||||||
};
|
};
|
||||||
const auto tree = asGaussianFactorGraphTree();
|
const auto tree = asGaussianFactorGraphTree();
|
||||||
return sum.empty() ? tree : sum.apply(tree, add);
|
return sum.empty() ? tree : sum.apply(tree, add);
|
||||||
|
|
@ -97,11 +97,7 @@ GaussianFactorGraphTree GaussianMixtureFactor::add(
|
||||||
/* *******************************************************************************/
|
/* *******************************************************************************/
|
||||||
GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
|
GaussianFactorGraphTree GaussianMixtureFactor::asGaussianFactorGraphTree()
|
||||||
const {
|
const {
|
||||||
auto wrap = [](const sharedFactor &gf) {
|
auto wrap = [](const sharedFactor &gf) { return GaussianFactorGraph{gf}; };
|
||||||
GaussianFactorGraph result;
|
|
||||||
result.push_back(gf);
|
|
||||||
return GraphAndConstant(result, 0.0);
|
|
||||||
};
|
|
||||||
return {factors_, wrap};
|
return {factors_, wrap};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -18,11 +18,11 @@
|
||||||
#pragma once
|
#pragma once
|
||||||
|
|
||||||
#include <gtsam/base/Testable.h>
|
#include <gtsam/base/Testable.h>
|
||||||
|
#include <gtsam/discrete/DecisionTree.h>
|
||||||
#include <gtsam/discrete/DiscreteKey.h>
|
#include <gtsam/discrete/DiscreteKey.h>
|
||||||
#include <gtsam/inference/Factor.h>
|
#include <gtsam/inference/Factor.h>
|
||||||
#include <gtsam/nonlinear/Values.h>
|
|
||||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||||
#include <gtsam/discrete/DecisionTree.h>
|
#include <gtsam/nonlinear/Values.h>
|
||||||
|
|
||||||
#include <cstddef>
|
#include <cstddef>
|
||||||
#include <string>
|
#include <string>
|
||||||
|
|
@ -30,35 +30,8 @@ namespace gtsam {
|
||||||
|
|
||||||
class HybridValues;
|
class HybridValues;
|
||||||
|
|
||||||
/// Gaussian factor graph and log of normalizing constant.
|
|
||||||
struct GraphAndConstant {
|
|
||||||
GaussianFactorGraph graph;
|
|
||||||
double constant;
|
|
||||||
|
|
||||||
GraphAndConstant(const GaussianFactorGraph &graph, double constant)
|
|
||||||
: graph(graph), constant(constant) {}
|
|
||||||
|
|
||||||
// Check pointer equality.
|
|
||||||
bool operator==(const GraphAndConstant &other) const {
|
|
||||||
return graph == other.graph && constant == other.constant;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Implement GTSAM-style print:
|
|
||||||
void print(const std::string &s = "Graph: ",
|
|
||||||
const KeyFormatter &formatter = DefaultKeyFormatter) const {
|
|
||||||
graph.print(s, formatter);
|
|
||||||
std::cout << "Constant: " << constant << std::endl;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Implement GTSAM-style equals:
|
|
||||||
bool equals(const GraphAndConstant &other, double tol = 1e-9) const {
|
|
||||||
return graph.equals(other.graph, tol) &&
|
|
||||||
fabs(constant - other.constant) < tol;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
/// Alias for DecisionTree of GaussianFactorGraphs
|
/// Alias for DecisionTree of GaussianFactorGraphs
|
||||||
using GaussianFactorGraphTree = DecisionTree<Key, GraphAndConstant>;
|
using GaussianFactorGraphTree = DecisionTree<Key, GaussianFactorGraph>;
|
||||||
|
|
||||||
KeyVector CollectKeys(const KeyVector &continuousKeys,
|
KeyVector CollectKeys(const KeyVector &continuousKeys,
|
||||||
const DiscreteKeys &discreteKeys);
|
const DiscreteKeys &discreteKeys);
|
||||||
|
|
@ -182,7 +155,4 @@ class GTSAM_EXPORT HybridFactor : public Factor {
|
||||||
template <>
|
template <>
|
||||||
struct traits<HybridFactor> : public Testable<HybridFactor> {};
|
struct traits<HybridFactor> : public Testable<HybridFactor> {};
|
||||||
|
|
||||||
template <>
|
|
||||||
struct traits<GraphAndConstant> : public Testable<GraphAndConstant> {};
|
|
||||||
|
|
||||||
} // namespace gtsam
|
} // namespace gtsam
|
||||||
|
|
|
||||||
|
|
@ -82,14 +82,13 @@ static GaussianFactorGraphTree addGaussian(
|
||||||
const GaussianFactor::shared_ptr &factor) {
|
const GaussianFactor::shared_ptr &factor) {
|
||||||
// If the decision tree is not initialized, then initialize it.
|
// If the decision tree is not initialized, then initialize it.
|
||||||
if (gfgTree.empty()) {
|
if (gfgTree.empty()) {
|
||||||
GaussianFactorGraph result;
|
GaussianFactorGraph result{factor};
|
||||||
result.push_back(factor);
|
return GaussianFactorGraphTree(result);
|
||||||
return GaussianFactorGraphTree(GraphAndConstant(result, 0.0));
|
|
||||||
} else {
|
} else {
|
||||||
auto add = [&factor](const GraphAndConstant &graph_z) {
|
auto add = [&factor](const GaussianFactorGraph &graph) {
|
||||||
auto result = graph_z.graph;
|
auto result = graph;
|
||||||
result.push_back(factor);
|
result.push_back(factor);
|
||||||
return GraphAndConstant(result, graph_z.constant);
|
return result;
|
||||||
};
|
};
|
||||||
return gfgTree.apply(add);
|
return gfgTree.apply(add);
|
||||||
}
|
}
|
||||||
|
|
@ -190,12 +189,13 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
||||||
// If any GaussianFactorGraph in the decision tree contains a nullptr, convert
|
// If any GaussianFactorGraph in the decision tree contains a nullptr, convert
|
||||||
// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
|
// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
|
||||||
// otherwise create a GFG with a single (null) factor.
|
// otherwise create a GFG with a single (null) factor.
|
||||||
|
// TODO(dellaert): still a mystery to me why this is needed.
|
||||||
GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
|
GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
|
||||||
auto emptyGaussian = [](const GraphAndConstant &graph_z) {
|
auto emptyGaussian = [](const GaussianFactorGraph &graph) {
|
||||||
bool hasNull =
|
bool hasNull =
|
||||||
std::any_of(graph_z.graph.begin(), graph_z.graph.end(),
|
std::any_of(graph.begin(), graph.end(),
|
||||||
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
|
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
|
||||||
return hasNull ? GraphAndConstant{GaussianFactorGraph(), 0.0} : graph_z;
|
return hasNull ? GaussianFactorGraph() : graph;
|
||||||
};
|
};
|
||||||
return GaussianFactorGraphTree(sum, emptyGaussian);
|
return GaussianFactorGraphTree(sum, emptyGaussian);
|
||||||
}
|
}
|
||||||
|
|
@ -224,8 +224,9 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
GaussianMixtureFactor::sharedFactor>;
|
GaussianMixtureFactor::sharedFactor>;
|
||||||
|
|
||||||
// This is the elimination method on the leaf nodes
|
// This is the elimination method on the leaf nodes
|
||||||
auto eliminateFunc = [&](const GraphAndConstant &graph_z) -> EliminationPair {
|
auto eliminateFunc =
|
||||||
if (graph_z.graph.empty()) {
|
[&](const GaussianFactorGraph &graph) -> EliminationPair {
|
||||||
|
if (graph.empty()) {
|
||||||
return {nullptr, nullptr};
|
return {nullptr, nullptr};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -236,12 +237,7 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
boost::shared_ptr<GaussianConditional> conditional;
|
boost::shared_ptr<GaussianConditional> conditional;
|
||||||
boost::shared_ptr<GaussianFactor> newFactor;
|
boost::shared_ptr<GaussianFactor> newFactor;
|
||||||
boost::tie(conditional, newFactor) =
|
boost::tie(conditional, newFactor) =
|
||||||
EliminatePreferCholesky(graph_z.graph, frontalKeys);
|
EliminatePreferCholesky(graph, frontalKeys);
|
||||||
|
|
||||||
// Get the log of the log normalization constant inverse and
|
|
||||||
// add it to the previous constant.
|
|
||||||
// const double logZ =
|
|
||||||
// graph_z.constant - conditional->logNormalizationConstant();
|
|
||||||
|
|
||||||
#ifdef HYBRID_TIMING
|
#ifdef HYBRID_TIMING
|
||||||
gttoc_(hybrid_eliminate);
|
gttoc_(hybrid_eliminate);
|
||||||
|
|
@ -271,15 +267,18 @@ hybridElimination(const HybridGaussianFactorGraph &factors,
|
||||||
// If there are no more continuous parents, then we should create a
|
// If there are no more continuous parents, then we should create a
|
||||||
// DiscreteFactor here, with the error for each discrete choice.
|
// DiscreteFactor here, with the error for each discrete choice.
|
||||||
if (continuousSeparator.empty()) {
|
if (continuousSeparator.empty()) {
|
||||||
auto probPrime = [&](const GaussianMixtureFactor::sharedFactor &factor) {
|
auto probPrime = [&](const EliminationPair &pair) {
|
||||||
// This is the unnormalized probability q(μ) at the mean.
|
// This is the unnormalized probability q(μ) at the mean.
|
||||||
// The factor has no keys, just contains the residual.
|
// The factor has no keys, just contains the residual.
|
||||||
static const VectorValues kEmpty;
|
static const VectorValues kEmpty;
|
||||||
return factor? exp(-factor->error(kEmpty)) : 1.0;
|
return pair.second ? exp(-pair.second->error(kEmpty)) /
|
||||||
|
pair.first->normalizationConstant()
|
||||||
|
: 1.0;
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto discreteFactor = boost::make_shared<DecisionTreeFactor>(
|
const auto discreteFactor = boost::make_shared<DecisionTreeFactor>(
|
||||||
discreteSeparator, DecisionTree<Key, double>(newFactors, probPrime));
|
discreteSeparator,
|
||||||
|
DecisionTree<Key, double>(eliminationResults, probPrime));
|
||||||
|
|
||||||
return {boost::make_shared<HybridConditional>(gaussianMixture),
|
return {boost::make_shared<HybridConditional>(gaussianMixture),
|
||||||
discreteFactor};
|
discreteFactor};
|
||||||
|
|
|
||||||
|
|
@ -89,8 +89,8 @@ TEST(GaussianMixtureFactor, Sum) {
|
||||||
mode[m1.first] = 1;
|
mode[m1.first] = 1;
|
||||||
mode[m2.first] = 2;
|
mode[m2.first] = 2;
|
||||||
auto actual = sum(mode);
|
auto actual = sum(mode);
|
||||||
EXPECT(actual.graph.at(0) == f11);
|
EXPECT(actual.at(0) == f11);
|
||||||
EXPECT(actual.graph.at(1) == f22);
|
EXPECT(actual.at(1) == f22);
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST(GaussianMixtureFactor, Printing) {
|
TEST(GaussianMixtureFactor, Printing) {
|
||||||
|
|
|
||||||
|
|
@ -640,9 +640,8 @@ TEST(HybridGaussianFactorGraph, assembleGraphTree) {
|
||||||
// Expected decision tree with two factor graphs:
|
// Expected decision tree with two factor graphs:
|
||||||
// f(x0;mode=0)P(x0) and f(x0;mode=1)P(x0)
|
// f(x0;mode=0)P(x0) and f(x0;mode=1)P(x0)
|
||||||
GaussianFactorGraphTree expected{
|
GaussianFactorGraphTree expected{
|
||||||
M(0),
|
M(0), GaussianFactorGraph(std::vector<GF>{(*mixture)(d0), prior}),
|
||||||
{GaussianFactorGraph(std::vector<GF>{(*mixture)(d0), prior}), 0.0},
|
GaussianFactorGraph(std::vector<GF>{(*mixture)(d1), prior})};
|
||||||
{GaussianFactorGraph(std::vector<GF>{(*mixture)(d1), prior}), 0.0}};
|
|
||||||
|
|
||||||
EXPECT(assert_equal(expected(d0), actual(d0), 1e-5));
|
EXPECT(assert_equal(expected(d0), actual(d0), 1e-5));
|
||||||
EXPECT(assert_equal(expected(d1), actual(d1), 1e-5));
|
EXPECT(assert_equal(expected(d1), actual(d1), 1e-5));
|
||||||
|
|
@ -700,7 +699,6 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
|
||||||
const VectorValues measurements{{Z(0), Vector1(5.0)}};
|
const VectorValues measurements{{Z(0), Vector1(5.0)}};
|
||||||
auto bn = tiny::createHybridBayesNet(num_measurements);
|
auto bn = tiny::createHybridBayesNet(num_measurements);
|
||||||
auto fg = bn.toFactorGraph(measurements);
|
auto fg = bn.toFactorGraph(measurements);
|
||||||
GTSAM_PRINT(bn);
|
|
||||||
EXPECT_LONGS_EQUAL(3, fg.size());
|
EXPECT_LONGS_EQUAL(3, fg.size());
|
||||||
|
|
||||||
EXPECT(ratioTest(bn, measurements, fg));
|
EXPECT(ratioTest(bn, measurements, fg));
|
||||||
|
|
@ -724,7 +722,6 @@ TEST(HybridGaussianFactorGraph, EliminateTiny1) {
|
||||||
// Test elimination
|
// Test elimination
|
||||||
const auto posterior = fg.eliminateSequential();
|
const auto posterior = fg.eliminateSequential();
|
||||||
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
|
EXPECT(assert_equal(expectedBayesNet, *posterior, 0.01));
|
||||||
GTSAM_PRINT(*posterior);
|
|
||||||
|
|
||||||
EXPECT(ratioTest(bn, measurements, *posterior));
|
EXPECT(ratioTest(bn, measurements, *posterior));
|
||||||
}
|
}
|
||||||
|
|
@ -847,7 +844,6 @@ TEST(HybridGaussianFactorGraph, EliminateSwitchingNetwork) {
|
||||||
|
|
||||||
// Do elimination:
|
// Do elimination:
|
||||||
const HybridBayesNet::shared_ptr posterior = fg.eliminateSequential(ordering);
|
const HybridBayesNet::shared_ptr posterior = fg.eliminateSequential(ordering);
|
||||||
// GTSAM_PRINT(*posterior);
|
|
||||||
|
|
||||||
// Test resulting posterior Bayes net has correct size:
|
// Test resulting posterior Bayes net has correct size:
|
||||||
EXPECT_LONGS_EQUAL(8, posterior->size());
|
EXPECT_LONGS_EQUAL(8, posterior->size());
|
||||||
|
|
|
||||||
|
|
@ -502,7 +502,8 @@ factor 0:
|
||||||
]
|
]
|
||||||
b = [ -10 ]
|
b = [ -10 ]
|
||||||
No noise model
|
No noise model
|
||||||
factor 1: Hybrid [x0 x1; m0]{
|
factor 1:
|
||||||
|
Hybrid [x0 x1; m0]{
|
||||||
Choice(m0)
|
Choice(m0)
|
||||||
0 Leaf :
|
0 Leaf :
|
||||||
A[x0] = [
|
A[x0] = [
|
||||||
|
|
@ -525,7 +526,8 @@ factor 1: Hybrid [x0 x1; m0]{
|
||||||
No noise model
|
No noise model
|
||||||
|
|
||||||
}
|
}
|
||||||
factor 2: Hybrid [x1 x2; m1]{
|
factor 2:
|
||||||
|
Hybrid [x1 x2; m1]{
|
||||||
Choice(m1)
|
Choice(m1)
|
||||||
0 Leaf :
|
0 Leaf :
|
||||||
A[x1] = [
|
A[x1] = [
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue