Switch to using HybridGaussianProductFactor

release/4.3a0
Frank Dellaert 2024-10-05 07:18:19 +09:00
parent 8b3dfd85e7
commit ed9a216365
9 changed files with 42 additions and 135 deletions
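To make the diff below easier to follow, here is a minimal sketch of the API shift this commit makes: the per-type helpers that built a GaussianFactorGraphTree (addGaussian, HybridGaussianFactor::add, asGaussianFactorGraphTree) are replaced by a HybridGaussianProductFactor accumulated with operator+=. The identifiers mirror the diff; the wrapper function and its arguments are illustrative only, not part of the commit.

#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/linear/GaussianFactor.h>

using namespace gtsam;

// Illustrative sketch: collect one plain Gaussian factor and one hybrid
// Gaussian factor into a single product factor, the same pattern the new
// collectProductFactor() loop uses below.
HybridGaussianProductFactor collectExample(
    const GaussianFactor::shared_ptr &gf, const HybridGaussianFactor &hgf) {
  HybridGaussianProductFactor product;
  product += gf;   // appended to the Gaussian factor graph at every leaf
  product += hgf;  // one Gaussian factor per discrete assignment
  return product;
}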

View File

@@ -32,9 +32,6 @@ namespace gtsam {
class HybridValues;
/// Alias for DecisionTree of GaussianFactorGraphs
using GaussianFactorGraphTree = DecisionTree<Key, GaussianFactorGraph>;
KeyVector CollectKeys(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys);
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);

View File

@@ -23,6 +23,7 @@
#include <gtsam/hybrid/HybridGaussianConditional.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianBayesNet.h>
#include <gtsam/linear/GaussianFactorGraph.h>
@@ -136,9 +137,9 @@ HybridGaussianConditional::conditionals() const {
}
/* *******************************************************************************/
GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree()
HybridGaussianProductFactor HybridGaussianConditional::asProductFactor()
const {
auto wrap = [this](const GaussianConditional::shared_ptr &gc) {
auto wrap = [this](const std::shared_ptr<GaussianConditional> &gc) {
// First check if conditional has not been pruned
if (gc) {
const double Cgm_Kgcm = gc->negLogConstant() - this->negLogConstant_;
@@ -155,7 +156,7 @@ GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree()
}
return GaussianFactorGraph{gc};
};
return {conditionals_, wrap};
return {{conditionals_, wrap}};
}
/* *******************************************************************************/
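A hedged usage sketch of asProductFactor() above: evaluating the returned product factor at one discrete assignment yields the Gaussian factor graph wrapping that mode's conditional, which is how the updated unit test at the end of this commit uses it. The conditional hgc and the mode key M(0) are illustrative, and the per-leaf type (GaussianFactorGraph) is inferred from that test rather than stated in this hunk.

#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridGaussianConditional.h>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianFactorGraph.h>

using namespace gtsam;
using symbol_shorthand::M;

// Illustrative: pick the Gaussian factor graph for mode M(0) == 1.
GaussianFactorGraph leafForMode1(const HybridGaussianConditional &hgc) {
  HybridGaussianProductFactor product = hgc.asProductFactor();
  DiscreteValues mode;
  mode[M(0)] = 1;  // assumes hgc has a single mode key M(0)
  return product(mode);
}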

View File

@@ -24,6 +24,7 @@
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/inference/Conditional.h>
@@ -221,6 +222,9 @@ class GTSAM_EXPORT HybridGaussianConditional
HybridGaussianConditional::shared_ptr prune(
const DecisionTreeFactor &discreteProbs) const;
/// Convert to a HybridGaussianProductFactor, a DecisionTree of Gaussian factor graphs.
HybridGaussianProductFactor asProductFactor() const override;
/// @}
private:
@@ -231,9 +235,6 @@ class GTSAM_EXPORT HybridGaussianConditional
HybridGaussianConditional(const DiscreteKeys &discreteParents,
const Helper &helper);
/// Convert to a DecisionTree of Gaussian factor graphs.
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
/// Check whether `given` has values for all frontal keys.
bool allFrontalsGiven(const VectorValues &given) const;

View File

@@ -195,26 +195,6 @@ HybridGaussianFactor::sharedFactor HybridGaussianFactor::operator()(
return factors_(assignment).first;
}
/* *******************************************************************************/
GaussianFactorGraphTree HybridGaussianFactor::add(
const GaussianFactorGraphTree &sum) const {
using Y = GaussianFactorGraph;
auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1;
result.push_back(graph2);
return result;
};
const auto tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add);
}
/* *******************************************************************************/
GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree()
const {
auto wrap = [](const auto &pair) { return GaussianFactorGraph{pair.first}; };
return {factors_, wrap};
}
/* *******************************************************************************/
HybridGaussianProductFactor HybridGaussianFactor::asProductFactor() const {
return {{factors_,

View File

@@ -130,16 +130,6 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
/// Get factor at a given discrete assignment.
sharedFactor operator()(const DiscreteValues &assignment) const;
/**
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
* maintaining the original tree structure.
*
* @param sum Decision Tree of Gaussian Factor Graphs indexed by the
* variables.
* @return Sum
*/
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
/**
* @brief Compute error of the HybridGaussianFactor as a tree.
*
@@ -158,14 +148,6 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
/// Getter for GaussianFactor decision tree
const FactorValuePairs &factors() const { return factors_; }
/// Add HybridNonlinearFactor to a Sum, syntactic sugar.
friend GaussianFactorGraphTree &operator+=(
GaussianFactorGraphTree &sum, const HybridGaussianFactor &factor) {
sum = factor.add(sum);
return sum;
}
/**
* @brief Helper function to return factors and functional to create a
* DecisionTree of Gaussian Factor Graphs.
@@ -175,16 +157,6 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
virtual HybridGaussianProductFactor asProductFactor() const;
/// @}
protected:
/**
* @brief Helper function to return factors and functional to create a
* DecisionTree of Gaussian Factor Graphs.
*
* @return GaussianFactorGraphTree
*/
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
private:
/**
* @brief Helper function to augment the [A|b] matrices in the factor

View File

@@ -127,43 +127,28 @@ void HybridGaussianFactorGraph::printErrors(
std::cout.flush();
}
/* ************************************************************************ */
static GaussianFactorGraphTree addGaussian(
const GaussianFactorGraphTree &gfgTree,
const GaussianFactor::shared_ptr &factor) {
// If the decision tree is not initialized, then initialize it.
if (gfgTree.empty()) {
GaussianFactorGraph result{factor};
return GaussianFactorGraphTree(result);
} else {
auto add = [&factor](const GaussianFactorGraph &graph) {
auto result = graph;
result.push_back(factor);
return result;
};
return gfgTree.apply(add);
}
}
/* ************************************************************************ */
// TODO(dellaert): it's probably more efficient to first collect the discrete
// keys, and then loop over all assignments to populate a vector.
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
GaussianFactorGraphTree result;
HybridGaussianProductFactor
HybridGaussianFactorGraph::collectProductFactor() const {
HybridGaussianProductFactor result;
for (auto &f : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor.
// TODO(dellaert): can we make this cleaner and less error-prone?
if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) {
result = addGaussian(result, gf);
result += gf;
} else if (auto gc = dynamic_pointer_cast<GaussianConditional>(f)) {
result += gc;
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
result = gmf->add(result);
result += *gmf;
} else if (auto gm = dynamic_pointer_cast<HybridGaussianConditional>(f)) {
result = gm->add(result);
result += *gm; // handled above already?
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
if (auto gm = hc->asHybrid()) {
result = gm->add(result);
result += *gm;
} else if (auto g = hc->asGaussian()) {
result = addGaussian(result, g);
result += g;
} else {
// Has to be discrete.
// TODO(dellaert): in C++20, we can use std::visit.
@@ -176,7 +161,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
} else {
// TODO(dellaert): there was an unattributed comment here: We need to
// handle the case where the object is actually a BayesTreeOrphanWrapper!
throwRuntimeError("gtsam::assembleGraphTree", f);
throwRuntimeError("gtsam::collectProductFactor", f);
}
}
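A short usage sketch connecting the renamed collectProductFactor() to the rest of this commit: the collected product factor is cleaned of null leaves via removeEmpty() (see the eliminate() hunk further down) and can then be evaluated per discrete assignment, as the updated unit test does. The graph argument and the single mode key M(0) are illustrative.

#include <gtsam/discrete/DiscreteValues.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianFactorGraph.h>

using namespace gtsam;
using symbol_shorthand::M;

// Illustrative: the Gaussian factor graph selected by one mode assignment.
GaussianFactorGraph gaussianGraphForMode(const HybridGaussianFactorGraph &hfg,
                                         size_t modeValue) {
  HybridGaussianProductFactor product = hfg.collectProductFactor();
  auto pruned = product.removeEmpty();  // leaves containing a nullptr -> empty
  DiscreteValues mode;
  mode[M(0)] = modeValue;               // assumes a single mode key M(0)
  return pruned(mode);
}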
@@ -269,21 +254,6 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
return {std::make_shared<HybridConditional>(result.first), result.second};
}
/* ************************************************************************ */
// If any GaussianFactorGraph in the decision tree contains a nullptr, convert
// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
// otherwise create a GFG with a single (null) factor,
// which doesn't register as null.
GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
auto emptyGaussian = [](const GaussianFactorGraph &graph) {
bool hasNull =
std::any_of(graph.begin(), graph.end(),
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
return hasNull ? GaussianFactorGraph() : graph;
};
return GaussianFactorGraphTree(sum, emptyGaussian);
}
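The free function deleted above survives in this commit as a member of HybridGaussianProductFactor, called further down as productFactor.removeEmpty(). Below is a minimal sketch of such a member, assuming it simply mirrors the deleted logic; the actual body in HybridGaussianProductFactor.cpp may differ.

#include <algorithm>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h>

namespace gtsam {

// Sketch only: map every leaf, turning any graph that contains a nullptr
// factor into an empty GaussianFactorGraph, exactly as the deleted free
// function did for the old GaussianFactorGraphTree.
HybridGaussianProductFactor HybridGaussianProductFactor::removeEmpty() const {
  auto emptyGaussian = [](const GaussianFactorGraph &graph) {
    bool hasNull =
        std::any_of(graph.begin(), graph.end(),
                    [](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
    return hasNull ? GaussianFactorGraph() : graph;
  };
  return {{*this, emptyGaussian}};
}

}  // namespace gtsam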
/* ************************************************************************ */
using Result = std::pair<std::shared_ptr<GaussianConditional>,
HybridGaussianFactor::sharedFactor>;
@@ -334,7 +304,7 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
// and negative since we want log(k)
hf->constantTerm() += -2.0 * conditional->negLogConstant();
}
return {factor, 0.0};
return {factor, conditional->negLogConstant()};
};
DecisionTree<Key, GaussianFactorValuePair> newFactors(eliminationResults,
correct);
@@ -359,12 +329,12 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
// Collect all the factors to create a set of Gaussian factor graphs in a
// decision tree indexed by all discrete keys involved.
GaussianFactorGraphTree factorGraphTree = assembleGraphTree();
HybridGaussianProductFactor productFactor = collectProductFactor();
// Convert factor graphs with a nullptr to an empty factor graph.
// This is done after assembly since it is non-trivial to keep track of which
// FG has a nullptr as we're looping over the factors.
factorGraphTree = removeEmpty(factorGraphTree);
auto prunedProductFactor = productFactor.removeEmpty();
// This is the elimination method on the leaf nodes
bool someContinuousLeft = false;
@@ -383,7 +353,7 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
};
// Perform elimination!
DecisionTree<Key, Result> eliminationResults(factorGraphTree, eliminate);
DecisionTree<Key, Result> eliminationResults(prunedProductFactor, eliminate);
// If there are no more continuous parents we create a DiscreteFactor with the
// error for each discrete choice. Otherwise, create a HybridGaussianFactor

View File

@@ -218,7 +218,7 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
* one for A and one for B. The leaves of the tree will be the Gaussian
* factors that have only continuous keys.
*/
GaussianFactorGraphTree assembleGraphTree() const;
HybridGaussianProductFactor collectProductFactor() const;
/**
* @brief Eliminate the given continuous keys.

View File

@@ -82,40 +82,25 @@ TEST(HybridGaussianFactor, ConstructorVariants) {
}
/* ************************************************************************* */
// "Add" two hybrid factors together.
TEST(HybridGaussianFactor, Sum) {
TEST(HybridGaussianFactor, Keys) {
using namespace test_constructor;
DiscreteKey m2(2, 3);
auto A3 = Matrix::Zero(2, 3);
auto f20 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
auto f21 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
auto f22 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
// TODO(Frank): why specify keys at all? And: keys in factor should be *all*
// keys, deviating from Kevin's scheme. Should we index DT on DiscreteKey?
// Design review!
HybridGaussianFactor hybridFactorA(m1, {f10, f11});
HybridGaussianFactor hybridFactorB(m2, {f20, f21, f22});
// Check the number of keys matches what we expect
EXPECT_LONGS_EQUAL(3, hybridFactorA.keys().size());
EXPECT_LONGS_EQUAL(2, hybridFactorA.continuousKeys().size());
EXPECT_LONGS_EQUAL(1, hybridFactorA.discreteKeys().size());
// Create sum of two hybrid factors: it will be a decision tree now on both
// discrete variables m1 and m2:
GaussianFactorGraphTree sum;
sum += hybridFactorA;
sum += hybridFactorB;
DiscreteKey m2(2, 3);
auto A3 = Matrix::Zero(2, 3);
auto f20 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
auto f21 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
auto f22 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
HybridGaussianFactor hybridFactorB(m2, {f20, f21, f22});
// Let's check that this worked:
Assignment<Key> mode;
mode[m1.first] = 1;
mode[m2.first] = 2;
auto actual = sum(mode);
EXPECT(actual.at(0) == f11);
EXPECT(actual.at(1) == f22);
// Check the number of keys matches what we expect
EXPECT_LONGS_EQUAL(3, hybridFactorB.keys().size());
EXPECT_LONGS_EQUAL(2, hybridFactorB.continuousKeys().size());
EXPECT_LONGS_EQUAL(1, hybridFactorB.discreteKeys().size());
}
/* ************************************************************************* */

View File

@@ -29,6 +29,7 @@
#include <gtsam/hybrid/HybridGaussianConditional.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/hybrid/HybridGaussianISAM.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/BayesNet.h>
@@ -674,14 +675,14 @@ TEST(HybridGaussianFactorGraph, IncrementalErrorTree) {
/* ****************************************************************************/
// Check that collectProductFactor assembles Gaussian factor graphs for each
// assignment.
TEST(HybridGaussianFactorGraph, assembleGraphTree) {
TEST(HybridGaussianFactorGraph, collectProductFactor) {
const int num_measurements = 1;
auto fg = tiny::createHybridGaussianFactorGraph(
num_measurements, VectorValues{{Z(0), Vector1(5.0)}});
EXPECT_LONGS_EQUAL(3, fg.size());
// Collect the product factor:
auto actual = fg.assembleGraphTree();
auto actual = fg.collectProductFactor();
// Create expected decision tree with two factor graphs:
@@ -701,9 +702,9 @@ TEST(HybridGaussianFactorGraph, assembleGraphTree) {
// Expected decision tree with two factor graphs:
// f(x0;mode=0)P(x0) and f(x0;mode=1)P(x0)
GaussianFactorGraphTree expected{
M(0), GaussianFactorGraph(std::vector<GF>{(*hybrid)(d0), prior}),
GaussianFactorGraph(std::vector<GF>{(*hybrid)(d1), prior})};
HybridGaussianProductFactor expected{
{M(0), GaussianFactorGraph(std::vector<GF>{(*hybrid)(d0), prior}),
GaussianFactorGraph(std::vector<GF>{(*hybrid)(d1), prior})}};
EXPECT(assert_equal(expected(d0), actual(d0), 1e-5));
EXPECT(assert_equal(expected(d1), actual(d1), 1e-5));