Switch to using HybridGaussianProductFactor
parent
8b3dfd85e7
commit
ed9a216365
|
|
@ -32,9 +32,6 @@ namespace gtsam {
|
|||
|
||||
class HybridValues;
|
||||
|
||||
/// Alias for DecisionTree of GaussianFactorGraphs
|
||||
using GaussianFactorGraphTree = DecisionTree<Key, GaussianFactorGraph>;
|
||||
|
||||
KeyVector CollectKeys(const KeyVector &continuousKeys,
|
||||
const DiscreteKeys &discreteKeys);
|
||||
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@
|
|||
#include <gtsam/hybrid/HybridGaussianConditional.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
|
||||
#include <gtsam/inference/Conditional-inst.h>
|
||||
#include <gtsam/linear/GaussianBayesNet.h>
|
||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||
|
|
@ -136,9 +137,9 @@ HybridGaussianConditional::conditionals() const {
|
|||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree()
|
||||
HybridGaussianProductFactor HybridGaussianConditional::asProductFactor()
|
||||
const {
|
||||
auto wrap = [this](const GaussianConditional::shared_ptr &gc) {
|
||||
auto wrap = [this](const std::shared_ptr<GaussianConditional> &gc) {
|
||||
// First check if conditional has not been pruned
|
||||
if (gc) {
|
||||
const double Cgm_Kgcm = gc->negLogConstant() - this->negLogConstant_;
|
||||
|
|
@ -155,7 +156,7 @@ GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree()
|
|||
}
|
||||
return GaussianFactorGraph{gc};
|
||||
};
|
||||
return {conditionals_, wrap};
|
||||
return {{conditionals_, wrap}};
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@
|
|||
#include <gtsam/discrete/DecisionTree.h>
|
||||
#include <gtsam/discrete/DecisionTreeFactor.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||
#include <gtsam/inference/Conditional.h>
|
||||
|
|
@ -221,6 +222,9 @@ class GTSAM_EXPORT HybridGaussianConditional
|
|||
HybridGaussianConditional::shared_ptr prune(
|
||||
const DecisionTreeFactor &discreteProbs) const;
|
||||
|
||||
/// Convert to a DecisionTree of Gaussian factor graphs.
|
||||
HybridGaussianProductFactor asProductFactor() const override;
|
||||
|
||||
/// @}
|
||||
|
||||
private:
|
||||
|
|
@ -231,9 +235,6 @@ class GTSAM_EXPORT HybridGaussianConditional
|
|||
HybridGaussianConditional(const DiscreteKeys &discreteParents,
|
||||
const Helper &helper);
|
||||
|
||||
/// Convert to a DecisionTree of Gaussian factor graphs.
|
||||
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
|
||||
|
||||
/// Check whether `given` has values for all frontal keys.
|
||||
bool allFrontalsGiven(const VectorValues &given) const;
|
||||
|
||||
|
|
|
|||
|
|
@ -195,26 +195,6 @@ HybridGaussianFactor::sharedFactor HybridGaussianFactor::operator()(
|
|||
return factors_(assignment).first;
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
GaussianFactorGraphTree HybridGaussianFactor::add(
|
||||
const GaussianFactorGraphTree &sum) const {
|
||||
using Y = GaussianFactorGraph;
|
||||
auto add = [](const Y &graph1, const Y &graph2) {
|
||||
auto result = graph1;
|
||||
result.push_back(graph2);
|
||||
return result;
|
||||
};
|
||||
const auto tree = asGaussianFactorGraphTree();
|
||||
return sum.empty() ? tree : sum.apply(tree, add);
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree()
|
||||
const {
|
||||
auto wrap = [](const auto &pair) { return GaussianFactorGraph{pair.first}; };
|
||||
return {factors_, wrap};
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
HybridGaussianProductFactor HybridGaussianFactor::asProductFactor() const {
|
||||
return {{factors_,
|
||||
|
|
|
|||
|
|
@ -130,16 +130,6 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
|||
/// Get factor at a given discrete assignment.
|
||||
sharedFactor operator()(const DiscreteValues &assignment) const;
|
||||
|
||||
/**
|
||||
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
|
||||
* maintaining the original tree structure.
|
||||
*
|
||||
* @param sum Decision Tree of Gaussian Factor Graphs indexed by the
|
||||
* variables.
|
||||
* @return Sum
|
||||
*/
|
||||
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
|
||||
|
||||
/**
|
||||
* @brief Compute error of the HybridGaussianFactor as a tree.
|
||||
*
|
||||
|
|
@ -158,14 +148,6 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
|||
|
||||
/// Getter for GaussianFactor decision tree
|
||||
const FactorValuePairs &factors() const { return factors_; }
|
||||
|
||||
/// Add HybridNonlinearFactor to a Sum, syntactic sugar.
|
||||
friend GaussianFactorGraphTree &operator+=(
|
||||
GaussianFactorGraphTree &sum, const HybridGaussianFactor &factor) {
|
||||
sum = factor.add(sum);
|
||||
return sum;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Helper function to return factors and functional to create a
|
||||
* DecisionTree of Gaussian Factor Graphs.
|
||||
|
|
@ -175,16 +157,6 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
|||
virtual HybridGaussianProductFactor asProductFactor() const;
|
||||
/// @}
|
||||
|
||||
protected:
|
||||
/**
|
||||
* @brief Helper function to return factors and functional to create a
|
||||
* DecisionTree of Gaussian Factor Graphs.
|
||||
*
|
||||
* @return GaussianFactorGraphTree
|
||||
*/
|
||||
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
|
||||
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Helper function to augment the [A|b] matrices in the factor
|
||||
|
|
|
|||
|
|
@ -127,43 +127,28 @@ void HybridGaussianFactorGraph::printErrors(
|
|||
std::cout.flush();
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
static GaussianFactorGraphTree addGaussian(
|
||||
const GaussianFactorGraphTree &gfgTree,
|
||||
const GaussianFactor::shared_ptr &factor) {
|
||||
// If the decision tree is not initialized, then initialize it.
|
||||
if (gfgTree.empty()) {
|
||||
GaussianFactorGraph result{factor};
|
||||
return GaussianFactorGraphTree(result);
|
||||
} else {
|
||||
auto add = [&factor](const GaussianFactorGraph &graph) {
|
||||
auto result = graph;
|
||||
result.push_back(factor);
|
||||
return result;
|
||||
};
|
||||
return gfgTree.apply(add);
|
||||
}
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
// TODO(dellaert): it's probably more efficient to first collect the discrete
|
||||
// keys, and then loop over all assignments to populate a vector.
|
||||
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
||||
GaussianFactorGraphTree result;
|
||||
HybridGaussianProductFactor
|
||||
HybridGaussianFactorGraph::collectProductFactor() const {
|
||||
HybridGaussianProductFactor result;
|
||||
|
||||
for (auto &f : factors_) {
|
||||
// TODO(dellaert): just use a virtual method defined in HybridFactor.
|
||||
// TODO(dellaert): can we make this cleaner and less error-prone?
|
||||
if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) {
|
||||
result = addGaussian(result, gf);
|
||||
result += gf;
|
||||
} else if (auto gc = dynamic_pointer_cast<GaussianConditional>(f)) {
|
||||
result += gc;
|
||||
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
|
||||
result = gmf->add(result);
|
||||
result += *gmf;
|
||||
} else if (auto gm = dynamic_pointer_cast<HybridGaussianConditional>(f)) {
|
||||
result = gm->add(result);
|
||||
result += *gm; // handled above already?
|
||||
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
|
||||
if (auto gm = hc->asHybrid()) {
|
||||
result = gm->add(result);
|
||||
result += *gm;
|
||||
} else if (auto g = hc->asGaussian()) {
|
||||
result = addGaussian(result, g);
|
||||
result += g;
|
||||
} else {
|
||||
// Has to be discrete.
|
||||
// TODO(dellaert): in C++20, we can use std::visit.
|
||||
|
|
@ -176,7 +161,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
|
|||
} else {
|
||||
// TODO(dellaert): there was an unattributed comment here: We need to
|
||||
// handle the case where the object is actually an BayesTreeOrphanWrapper!
|
||||
throwRuntimeError("gtsam::assembleGraphTree", f);
|
||||
throwRuntimeError("gtsam::collectProductFactor", f);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -269,21 +254,6 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
|
|||
return {std::make_shared<HybridConditional>(result.first), result.second};
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
// If any GaussianFactorGraph in the decision tree contains a nullptr, convert
|
||||
// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
|
||||
// otherwise create a GFG with a single (null) factor,
|
||||
// which doesn't register as null.
|
||||
GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
|
||||
auto emptyGaussian = [](const GaussianFactorGraph &graph) {
|
||||
bool hasNull =
|
||||
std::any_of(graph.begin(), graph.end(),
|
||||
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
|
||||
return hasNull ? GaussianFactorGraph() : graph;
|
||||
};
|
||||
return GaussianFactorGraphTree(sum, emptyGaussian);
|
||||
}
|
||||
|
||||
/* ************************************************************************ */
|
||||
using Result = std::pair<std::shared_ptr<GaussianConditional>,
|
||||
HybridGaussianFactor::sharedFactor>;
|
||||
|
|
@ -334,7 +304,7 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
|
|||
// and negative since we want log(k)
|
||||
hf->constantTerm() += -2.0 * conditional->negLogConstant();
|
||||
}
|
||||
return {factor, 0.0};
|
||||
return {factor, conditional->negLogConstant()};
|
||||
};
|
||||
DecisionTree<Key, GaussianFactorValuePair> newFactors(eliminationResults,
|
||||
correct);
|
||||
|
|
@ -359,12 +329,12 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
|
|||
|
||||
// Collect all the factors to create a set of Gaussian factor graphs in a
|
||||
// decision tree indexed by all discrete keys involved.
|
||||
GaussianFactorGraphTree factorGraphTree = assembleGraphTree();
|
||||
HybridGaussianProductFactor productFactor = collectProductFactor();
|
||||
|
||||
// Convert factor graphs with a nullptr to an empty factor graph.
|
||||
// This is done after assembly since it is non-trivial to keep track of which
|
||||
// FG has a nullptr as we're looping over the factors.
|
||||
factorGraphTree = removeEmpty(factorGraphTree);
|
||||
auto prunedProductFactor = productFactor.removeEmpty();
|
||||
|
||||
// This is the elimination method on the leaf nodes
|
||||
bool someContinuousLeft = false;
|
||||
|
|
@ -383,7 +353,7 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
|
|||
};
|
||||
|
||||
// Perform elimination!
|
||||
DecisionTree<Key, Result> eliminationResults(factorGraphTree, eliminate);
|
||||
DecisionTree<Key, Result> eliminationResults(prunedProductFactor, eliminate);
|
||||
|
||||
// If there are no more continuous parents we create a DiscreteFactor with the
|
||||
// error for each discrete choice. Otherwise, create a HybridGaussianFactor
|
||||
|
|
|
|||
|
|
@ -218,7 +218,7 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
|
|||
* one for A and one for B. The leaves of the tree will be the Gaussian
|
||||
* factors that have only continuous keys.
|
||||
*/
|
||||
GaussianFactorGraphTree assembleGraphTree() const;
|
||||
HybridGaussianProductFactor collectProductFactor() const;
|
||||
|
||||
/**
|
||||
* @brief Eliminate the given continuous keys.
|
||||
|
|
|
|||
|
|
@ -82,40 +82,25 @@ TEST(HybridGaussianFactor, ConstructorVariants) {
|
|||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// "Add" two hybrid factors together.
|
||||
TEST(HybridGaussianFactor, Sum) {
|
||||
TEST(HybridGaussianFactor, Keys) {
|
||||
using namespace test_constructor;
|
||||
DiscreteKey m2(2, 3);
|
||||
|
||||
auto A3 = Matrix::Zero(2, 3);
|
||||
auto f20 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
|
||||
auto f21 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
|
||||
auto f22 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
|
||||
|
||||
// TODO(Frank): why specify keys at all? And: keys in factor should be *all*
|
||||
// keys, deviating from Kevin's scheme. Should we index DT on DiscreteKey?
|
||||
// Design review!
|
||||
HybridGaussianFactor hybridFactorA(m1, {f10, f11});
|
||||
HybridGaussianFactor hybridFactorB(m2, {f20, f21, f22});
|
||||
|
||||
// Check the number of keys matches what we expect
|
||||
EXPECT_LONGS_EQUAL(3, hybridFactorA.keys().size());
|
||||
EXPECT_LONGS_EQUAL(2, hybridFactorA.continuousKeys().size());
|
||||
EXPECT_LONGS_EQUAL(1, hybridFactorA.discreteKeys().size());
|
||||
|
||||
// Create sum of two hybrid factors: it will be a decision tree now on both
|
||||
// discrete variables m1 and m2:
|
||||
GaussianFactorGraphTree sum;
|
||||
sum += hybridFactorA;
|
||||
sum += hybridFactorB;
|
||||
DiscreteKey m2(2, 3);
|
||||
auto A3 = Matrix::Zero(2, 3);
|
||||
auto f20 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
|
||||
auto f21 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
|
||||
auto f22 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
|
||||
HybridGaussianFactor hybridFactorB(m2, {f20, f21, f22});
|
||||
|
||||
// Let's check that this worked:
|
||||
Assignment<Key> mode;
|
||||
mode[m1.first] = 1;
|
||||
mode[m2.first] = 2;
|
||||
auto actual = sum(mode);
|
||||
EXPECT(actual.at(0) == f11);
|
||||
EXPECT(actual.at(1) == f22);
|
||||
// Check the number of keys matches what we expect
|
||||
EXPECT_LONGS_EQUAL(3, hybridFactorB.keys().size());
|
||||
EXPECT_LONGS_EQUAL(2, hybridFactorB.continuousKeys().size());
|
||||
EXPECT_LONGS_EQUAL(1, hybridFactorB.discreteKeys().size());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@
|
|||
#include <gtsam/hybrid/HybridGaussianConditional.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
|
||||
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
|
||||
#include <gtsam/hybrid/HybridGaussianISAM.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
#include <gtsam/inference/BayesNet.h>
|
||||
|
|
@ -674,14 +675,14 @@ TEST(HybridGaussianFactorGraph, IncrementalErrorTree) {
|
|||
/* ****************************************************************************/
|
||||
// Check that assembleGraphTree assembles Gaussian factor graphs for each
|
||||
// assignment.
|
||||
TEST(HybridGaussianFactorGraph, assembleGraphTree) {
|
||||
TEST(HybridGaussianFactorGraph, collectProductFactor) {
|
||||
const int num_measurements = 1;
|
||||
auto fg = tiny::createHybridGaussianFactorGraph(
|
||||
num_measurements, VectorValues{{Z(0), Vector1(5.0)}});
|
||||
EXPECT_LONGS_EQUAL(3, fg.size());
|
||||
|
||||
// Assemble graph tree:
|
||||
auto actual = fg.assembleGraphTree();
|
||||
auto actual = fg.collectProductFactor();
|
||||
|
||||
// Create expected decision tree with two factor graphs:
|
||||
|
||||
|
|
@ -701,9 +702,9 @@ TEST(HybridGaussianFactorGraph, assembleGraphTree) {
|
|||
|
||||
// Expected decision tree with two factor graphs:
|
||||
// f(x0;mode=0)P(x0) and f(x0;mode=1)P(x0)
|
||||
GaussianFactorGraphTree expected{
|
||||
M(0), GaussianFactorGraph(std::vector<GF>{(*hybrid)(d0), prior}),
|
||||
GaussianFactorGraph(std::vector<GF>{(*hybrid)(d1), prior})};
|
||||
HybridGaussianProductFactor expected{
|
||||
{M(0), GaussianFactorGraph(std::vector<GF>{(*hybrid)(d0), prior}),
|
||||
GaussianFactorGraph(std::vector<GF>{(*hybrid)(d1), prior})}};
|
||||
|
||||
EXPECT(assert_equal(expected(d0), actual(d0), 1e-5));
|
||||
EXPECT(assert_equal(expected(d1), actual(d1), 1e-5));
|
||||
|
|
|
|||
Loading…
Reference in New Issue