Switch to using HybridGaussianProductFactor

release/4.3a0
Frank Dellaert 2024-10-05 07:18:19 +09:00
parent 8b3dfd85e7
commit ed9a216365
9 changed files with 42 additions and 135 deletions

View File

@ -32,9 +32,6 @@ namespace gtsam {
class HybridValues; class HybridValues;
/// Alias for DecisionTree of GaussianFactorGraphs
using GaussianFactorGraphTree = DecisionTree<Key, GaussianFactorGraph>;
KeyVector CollectKeys(const KeyVector &continuousKeys, KeyVector CollectKeys(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys); const DiscreteKeys &discreteKeys);
KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2); KeyVector CollectKeys(const KeyVector &keys1, const KeyVector &keys2);

View File

@ -23,6 +23,7 @@
#include <gtsam/hybrid/HybridGaussianConditional.h> #include <gtsam/hybrid/HybridGaussianConditional.h>
#include <gtsam/hybrid/HybridGaussianFactor.h> #include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridValues.h> #include <gtsam/hybrid/HybridValues.h>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/inference/Conditional-inst.h> #include <gtsam/inference/Conditional-inst.h>
#include <gtsam/linear/GaussianBayesNet.h> #include <gtsam/linear/GaussianBayesNet.h>
#include <gtsam/linear/GaussianFactorGraph.h> #include <gtsam/linear/GaussianFactorGraph.h>
@ -136,9 +137,9 @@ HybridGaussianConditional::conditionals() const {
} }
/* *******************************************************************************/ /* *******************************************************************************/
GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree() HybridGaussianProductFactor HybridGaussianConditional::asProductFactor()
const { const {
auto wrap = [this](const GaussianConditional::shared_ptr &gc) { auto wrap = [this](const std::shared_ptr<GaussianConditional> &gc) {
// First check if conditional has not been pruned // First check if conditional has not been pruned
if (gc) { if (gc) {
const double Cgm_Kgcm = gc->negLogConstant() - this->negLogConstant_; const double Cgm_Kgcm = gc->negLogConstant() - this->negLogConstant_;
@ -155,7 +156,7 @@ GaussianFactorGraphTree HybridGaussianConditional::asGaussianFactorGraphTree()
} }
return GaussianFactorGraph{gc}; return GaussianFactorGraph{gc};
}; };
return {conditionals_, wrap}; return {{conditionals_, wrap}};
} }
/* *******************************************************************************/ /* *******************************************************************************/

View File

@ -24,6 +24,7 @@
#include <gtsam/discrete/DecisionTree.h> #include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DecisionTreeFactor.h> #include <gtsam/discrete/DecisionTreeFactor.h>
#include <gtsam/discrete/DiscreteKey.h> #include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/hybrid/HybridFactor.h> #include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h> #include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/inference/Conditional.h> #include <gtsam/inference/Conditional.h>
@ -221,6 +222,9 @@ class GTSAM_EXPORT HybridGaussianConditional
HybridGaussianConditional::shared_ptr prune( HybridGaussianConditional::shared_ptr prune(
const DecisionTreeFactor &discreteProbs) const; const DecisionTreeFactor &discreteProbs) const;
/// Convert to a DecisionTree of Gaussian factor graphs.
HybridGaussianProductFactor asProductFactor() const override;
/// @} /// @}
private: private:
@ -231,9 +235,6 @@ class GTSAM_EXPORT HybridGaussianConditional
HybridGaussianConditional(const DiscreteKeys &discreteParents, HybridGaussianConditional(const DiscreteKeys &discreteParents,
const Helper &helper); const Helper &helper);
/// Convert to a DecisionTree of Gaussian factor graphs.
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
/// Check whether `given` has values for all frontal keys. /// Check whether `given` has values for all frontal keys.
bool allFrontalsGiven(const VectorValues &given) const; bool allFrontalsGiven(const VectorValues &given) const;

View File

@ -195,26 +195,6 @@ HybridGaussianFactor::sharedFactor HybridGaussianFactor::operator()(
return factors_(assignment).first; return factors_(assignment).first;
} }
/* *******************************************************************************/
GaussianFactorGraphTree HybridGaussianFactor::add(
const GaussianFactorGraphTree &sum) const {
using Y = GaussianFactorGraph;
auto add = [](const Y &graph1, const Y &graph2) {
auto result = graph1;
result.push_back(graph2);
return result;
};
const auto tree = asGaussianFactorGraphTree();
return sum.empty() ? tree : sum.apply(tree, add);
}
/* *******************************************************************************/
GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree()
const {
auto wrap = [](const auto &pair) { return GaussianFactorGraph{pair.first}; };
return {factors_, wrap};
}
/* *******************************************************************************/ /* *******************************************************************************/
HybridGaussianProductFactor HybridGaussianFactor::asProductFactor() const { HybridGaussianProductFactor HybridGaussianFactor::asProductFactor() const {
return {{factors_, return {{factors_,

View File

@ -130,16 +130,6 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
/// Get factor at a given discrete assignment. /// Get factor at a given discrete assignment.
sharedFactor operator()(const DiscreteValues &assignment) const; sharedFactor operator()(const DiscreteValues &assignment) const;
/**
* @brief Combine the Gaussian Factor Graphs in `sum` and `this` while
* maintaining the original tree structure.
*
* @param sum Decision Tree of Gaussian Factor Graphs indexed by the
* variables.
* @return Sum
*/
GaussianFactorGraphTree add(const GaussianFactorGraphTree &sum) const;
/** /**
* @brief Compute error of the HybridGaussianFactor as a tree. * @brief Compute error of the HybridGaussianFactor as a tree.
* *
@ -158,14 +148,6 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
/// Getter for GaussianFactor decision tree /// Getter for GaussianFactor decision tree
const FactorValuePairs &factors() const { return factors_; } const FactorValuePairs &factors() const { return factors_; }
/// Add HybridNonlinearFactor to a Sum, syntactic sugar.
friend GaussianFactorGraphTree &operator+=(
GaussianFactorGraphTree &sum, const HybridGaussianFactor &factor) {
sum = factor.add(sum);
return sum;
}
/** /**
* @brief Helper function to return factors and functional to create a * @brief Helper function to return factors and functional to create a
* DecisionTree of Gaussian Factor Graphs. * DecisionTree of Gaussian Factor Graphs.
@ -175,16 +157,6 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
virtual HybridGaussianProductFactor asProductFactor() const; virtual HybridGaussianProductFactor asProductFactor() const;
/// @} /// @}
protected:
/**
* @brief Helper function to return factors and functional to create a
* DecisionTree of Gaussian Factor Graphs.
*
* @return GaussianFactorGraphTree
*/
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
private: private:
/** /**
* @brief Helper function to augment the [A|b] matrices in the factor * @brief Helper function to augment the [A|b] matrices in the factor

View File

@ -127,43 +127,28 @@ void HybridGaussianFactorGraph::printErrors(
std::cout.flush(); std::cout.flush();
} }
/* ************************************************************************ */
static GaussianFactorGraphTree addGaussian(
const GaussianFactorGraphTree &gfgTree,
const GaussianFactor::shared_ptr &factor) {
// If the decision tree is not initialized, then initialize it.
if (gfgTree.empty()) {
GaussianFactorGraph result{factor};
return GaussianFactorGraphTree(result);
} else {
auto add = [&factor](const GaussianFactorGraph &graph) {
auto result = graph;
result.push_back(factor);
return result;
};
return gfgTree.apply(add);
}
}
/* ************************************************************************ */ /* ************************************************************************ */
// TODO(dellaert): it's probably more efficient to first collect the discrete // TODO(dellaert): it's probably more efficient to first collect the discrete
// keys, and then loop over all assignments to populate a vector. // keys, and then loop over all assignments to populate a vector.
GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const { HybridGaussianProductFactor
GaussianFactorGraphTree result; HybridGaussianFactorGraph::collectProductFactor() const {
HybridGaussianProductFactor result;
for (auto &f : factors_) { for (auto &f : factors_) {
// TODO(dellaert): just use a virtual method defined in HybridFactor. // TODO(dellaert): can we make this cleaner and less error-prone?
if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) { if (auto gf = dynamic_pointer_cast<GaussianFactor>(f)) {
result = addGaussian(result, gf); result += gf;
} else if (auto gc = dynamic_pointer_cast<GaussianConditional>(f)) {
result += gc;
} else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) { } else if (auto gmf = dynamic_pointer_cast<HybridGaussianFactor>(f)) {
result = gmf->add(result); result += *gmf;
} else if (auto gm = dynamic_pointer_cast<HybridGaussianConditional>(f)) { } else if (auto gm = dynamic_pointer_cast<HybridGaussianConditional>(f)) {
result = gm->add(result); result += *gm; // handled above already?
} else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) { } else if (auto hc = dynamic_pointer_cast<HybridConditional>(f)) {
if (auto gm = hc->asHybrid()) { if (auto gm = hc->asHybrid()) {
result = gm->add(result); result += *gm;
} else if (auto g = hc->asGaussian()) { } else if (auto g = hc->asGaussian()) {
result = addGaussian(result, g); result += g;
} else { } else {
// Has to be discrete. // Has to be discrete.
// TODO(dellaert): in C++20, we can use std::visit. // TODO(dellaert): in C++20, we can use std::visit.
@ -176,7 +161,7 @@ GaussianFactorGraphTree HybridGaussianFactorGraph::assembleGraphTree() const {
} else { } else {
// TODO(dellaert): there was an unattributed comment here: We need to // TODO(dellaert): there was an unattributed comment here: We need to
// handle the case where the object is actually an BayesTreeOrphanWrapper! // handle the case where the object is actually an BayesTreeOrphanWrapper!
throwRuntimeError("gtsam::assembleGraphTree", f); throwRuntimeError("gtsam::collectProductFactor", f);
} }
} }
@ -269,21 +254,6 @@ discreteElimination(const HybridGaussianFactorGraph &factors,
return {std::make_shared<HybridConditional>(result.first), result.second}; return {std::make_shared<HybridConditional>(result.first), result.second};
} }
/* ************************************************************************ */
// If any GaussianFactorGraph in the decision tree contains a nullptr, convert
// that leaf to an empty GaussianFactorGraph. Needed since the DecisionTree will
// otherwise create a GFG with a single (null) factor,
// which doesn't register as null.
GaussianFactorGraphTree removeEmpty(const GaussianFactorGraphTree &sum) {
auto emptyGaussian = [](const GaussianFactorGraph &graph) {
bool hasNull =
std::any_of(graph.begin(), graph.end(),
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
return hasNull ? GaussianFactorGraph() : graph;
};
return GaussianFactorGraphTree(sum, emptyGaussian);
}
/* ************************************************************************ */ /* ************************************************************************ */
using Result = std::pair<std::shared_ptr<GaussianConditional>, using Result = std::pair<std::shared_ptr<GaussianConditional>,
HybridGaussianFactor::sharedFactor>; HybridGaussianFactor::sharedFactor>;
@ -334,7 +304,7 @@ static std::shared_ptr<Factor> createHybridGaussianFactor(
// and negative since we want log(k) // and negative since we want log(k)
hf->constantTerm() += -2.0 * conditional->negLogConstant(); hf->constantTerm() += -2.0 * conditional->negLogConstant();
} }
return {factor, 0.0}; return {factor, conditional->negLogConstant()};
}; };
DecisionTree<Key, GaussianFactorValuePair> newFactors(eliminationResults, DecisionTree<Key, GaussianFactorValuePair> newFactors(eliminationResults,
correct); correct);
@ -359,12 +329,12 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
// Collect all the factors to create a set of Gaussian factor graphs in a // Collect all the factors to create a set of Gaussian factor graphs in a
// decision tree indexed by all discrete keys involved. // decision tree indexed by all discrete keys involved.
GaussianFactorGraphTree factorGraphTree = assembleGraphTree(); HybridGaussianProductFactor productFactor = collectProductFactor();
// Convert factor graphs with a nullptr to an empty factor graph. // Convert factor graphs with a nullptr to an empty factor graph.
// This is done after assembly since it is non-trivial to keep track of which // This is done after assembly since it is non-trivial to keep track of which
// FG has a nullptr as we're looping over the factors. // FG has a nullptr as we're looping over the factors.
factorGraphTree = removeEmpty(factorGraphTree); auto prunedProductFactor = productFactor.removeEmpty();
// This is the elimination method on the leaf nodes // This is the elimination method on the leaf nodes
bool someContinuousLeft = false; bool someContinuousLeft = false;
@ -383,7 +353,7 @@ HybridGaussianFactorGraph::eliminate(const Ordering &keys) const {
}; };
// Perform elimination! // Perform elimination!
DecisionTree<Key, Result> eliminationResults(factorGraphTree, eliminate); DecisionTree<Key, Result> eliminationResults(prunedProductFactor, eliminate);
// If there are no more continuous parents we create a DiscreteFactor with the // If there are no more continuous parents we create a DiscreteFactor with the
// error for each discrete choice. Otherwise, create a HybridGaussianFactor // error for each discrete choice. Otherwise, create a HybridGaussianFactor

View File

@ -218,7 +218,7 @@ class GTSAM_EXPORT HybridGaussianFactorGraph
* one for A and one for B. The leaves of the tree will be the Gaussian * one for A and one for B. The leaves of the tree will be the Gaussian
* factors that have only continuous keys. * factors that have only continuous keys.
*/ */
GaussianFactorGraphTree assembleGraphTree() const; HybridGaussianProductFactor collectProductFactor() const;
/** /**
* @brief Eliminate the given continuous keys. * @brief Eliminate the given continuous keys.

View File

@ -82,40 +82,25 @@ TEST(HybridGaussianFactor, ConstructorVariants) {
} }
/* ************************************************************************* */ /* ************************************************************************* */
// "Add" two hybrid factors together. TEST(HybridGaussianFactor, Keys) {
TEST(HybridGaussianFactor, Sum) {
using namespace test_constructor; using namespace test_constructor;
DiscreteKey m2(2, 3);
auto A3 = Matrix::Zero(2, 3);
auto f20 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
auto f21 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
auto f22 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
// TODO(Frank): why specify keys at all? And: keys in factor should be *all*
// keys, deviating from Kevin's scheme. Should we index DT on DiscreteKey?
// Design review!
HybridGaussianFactor hybridFactorA(m1, {f10, f11}); HybridGaussianFactor hybridFactorA(m1, {f10, f11});
HybridGaussianFactor hybridFactorB(m2, {f20, f21, f22});
// Check the number of keys matches what we expect // Check the number of keys matches what we expect
EXPECT_LONGS_EQUAL(3, hybridFactorA.keys().size()); EXPECT_LONGS_EQUAL(3, hybridFactorA.keys().size());
EXPECT_LONGS_EQUAL(2, hybridFactorA.continuousKeys().size()); EXPECT_LONGS_EQUAL(2, hybridFactorA.continuousKeys().size());
EXPECT_LONGS_EQUAL(1, hybridFactorA.discreteKeys().size()); EXPECT_LONGS_EQUAL(1, hybridFactorA.discreteKeys().size());
// Create sum of two hybrid factors: it will be a decision tree now on both DiscreteKey m2(2, 3);
// discrete variables m1 and m2: auto A3 = Matrix::Zero(2, 3);
GaussianFactorGraphTree sum; auto f20 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
sum += hybridFactorA; auto f21 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
sum += hybridFactorB; auto f22 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
HybridGaussianFactor hybridFactorB(m2, {f20, f21, f22});
// Let's check that this worked: // Check the number of keys matches what we expect
Assignment<Key> mode; EXPECT_LONGS_EQUAL(3, hybridFactorB.keys().size());
mode[m1.first] = 1; EXPECT_LONGS_EQUAL(2, hybridFactorB.continuousKeys().size());
mode[m2.first] = 2; EXPECT_LONGS_EQUAL(1, hybridFactorB.discreteKeys().size());
auto actual = sum(mode);
EXPECT(actual.at(0) == f11);
EXPECT(actual.at(1) == f22);
} }
/* ************************************************************************* */ /* ************************************************************************* */

View File

@ -29,6 +29,7 @@
#include <gtsam/hybrid/HybridGaussianConditional.h> #include <gtsam/hybrid/HybridGaussianConditional.h>
#include <gtsam/hybrid/HybridGaussianFactor.h> #include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridGaussianFactorGraph.h> #include <gtsam/hybrid/HybridGaussianFactorGraph.h>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/hybrid/HybridGaussianISAM.h> #include <gtsam/hybrid/HybridGaussianISAM.h>
#include <gtsam/hybrid/HybridValues.h> #include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/BayesNet.h> #include <gtsam/inference/BayesNet.h>
@ -674,14 +675,14 @@ TEST(HybridGaussianFactorGraph, IncrementalErrorTree) {
/* ****************************************************************************/ /* ****************************************************************************/
// Check that assembleGraphTree assembles Gaussian factor graphs for each // Check that assembleGraphTree assembles Gaussian factor graphs for each
// assignment. // assignment.
TEST(HybridGaussianFactorGraph, assembleGraphTree) { TEST(HybridGaussianFactorGraph, collectProductFactor) {
const int num_measurements = 1; const int num_measurements = 1;
auto fg = tiny::createHybridGaussianFactorGraph( auto fg = tiny::createHybridGaussianFactorGraph(
num_measurements, VectorValues{{Z(0), Vector1(5.0)}}); num_measurements, VectorValues{{Z(0), Vector1(5.0)}});
EXPECT_LONGS_EQUAL(3, fg.size()); EXPECT_LONGS_EQUAL(3, fg.size());
// Assemble graph tree: // Assemble graph tree:
auto actual = fg.assembleGraphTree(); auto actual = fg.collectProductFactor();
// Create expected decision tree with two factor graphs: // Create expected decision tree with two factor graphs:
@ -701,9 +702,9 @@ TEST(HybridGaussianFactorGraph, assembleGraphTree) {
// Expected decision tree with two factor graphs: // Expected decision tree with two factor graphs:
// f(x0;mode=0)P(x0) and f(x0;mode=1)P(x0) // f(x0;mode=0)P(x0) and f(x0;mode=1)P(x0)
GaussianFactorGraphTree expected{ HybridGaussianProductFactor expected{
M(0), GaussianFactorGraph(std::vector<GF>{(*hybrid)(d0), prior}), {M(0), GaussianFactorGraph(std::vector<GF>{(*hybrid)(d0), prior}),
GaussianFactorGraph(std::vector<GF>{(*hybrid)(d1), prior})}; GaussianFactorGraph(std::vector<GF>{(*hybrid)(d1), prior})}};
EXPECT(assert_equal(expected(d0), actual(d0), 1e-5)); EXPECT(assert_equal(expected(d0), actual(d0), 1e-5));
EXPECT(assert_equal(expected(d1), actual(d1), 1e-5)); EXPECT(assert_equal(expected(d1), actual(d1), 1e-5));