New product factor class

release/4.3a0
Frank Dellaert 2024-10-02 08:39:18 -07:00
parent 1bb5b9551b
commit 8b3dfd85e7
5 changed files with 410 additions and 3 deletions

View File

@ -18,17 +18,17 @@
* @date Mar 12, 2022
*/
#include <gtsam/base/types.h>
#include <gtsam/base/utilities.h>
#include <gtsam/discrete/DecisionTree-inl.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h>
#include "gtsam/base/types.h"
namespace gtsam {
/* *******************************************************************************/
@ -215,6 +215,12 @@ GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree()
return {factors_, wrap};
}
/* *******************************************************************************/
HybridGaussianProductFactor HybridGaussianFactor::asProductFactor() const {
return {{factors_,
[](const auto &pair) { return GaussianFactorGraph{pair.first}; }}};
}
/* *******************************************************************************/
/// Helper method to compute the error of a component.
static double PotentiallyPrunedComponentError(

View File

@ -24,6 +24,7 @@
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/discrete/DiscreteKey.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h>
@ -164,6 +165,14 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
sum = factor.add(sum);
return sum;
}
/**
* @brief Helper function to return factors and functional to create a
* DecisionTree of Gaussian Factor Graphs.
*
* @return HybridGaussianProductFactor
*/
virtual HybridGaussianProductFactor asProductFactor() const;
/// @}
protected:
@ -175,7 +184,8 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
*/
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
private:
private:
/**
* @brief Helper function to augment the [A|b] matrices in the factor
* components with the additional scalar values. This is done by storing the

View File

@ -0,0 +1,89 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file HybridGaussianProductFactor.h
* @date Oct 2, 2024
* @author Frank Dellaert
* @author Varun Agrawal
*/
#include <gtsam/base/types.h>
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h>
namespace gtsam {
static GaussianFactorGraph add(const GaussianFactorGraph &graph1,
const GaussianFactorGraph &graph2) {
auto result = graph1;
result.push_back(graph2);
return result;
};
HybridGaussianProductFactor operator+(const HybridGaussianProductFactor &a,
const HybridGaussianProductFactor &b) {
return a.empty() ? b : HybridGaussianProductFactor(a.apply(b, add));
}
HybridGaussianProductFactor HybridGaussianProductFactor::operator+(
const HybridGaussianFactor &factor) const {
return *this + factor.asProductFactor();
}
HybridGaussianProductFactor HybridGaussianProductFactor::operator+(
const GaussianFactor::shared_ptr &factor) const {
return *this + HybridGaussianProductFactor(factor);
}
HybridGaussianProductFactor &HybridGaussianProductFactor::operator+=(
const GaussianFactor::shared_ptr &factor) {
*this = *this + factor;
return *this;
}
HybridGaussianProductFactor &
HybridGaussianProductFactor::operator+=(const HybridGaussianFactor &factor) {
*this = *this + factor;
return *this;
}
void HybridGaussianProductFactor::print(const std::string &s,
const KeyFormatter &formatter) const {
KeySet keys;
auto printer = [&](const Y &graph) {
if (keys.size() == 0)
keys = graph.keys();
return "Graph of size " + std::to_string(graph.size());
};
Base::print(s, formatter, printer);
if (keys.size() > 0) {
std::stringstream ss;
ss << s << " Keys:";
for (auto &&key : keys)
ss << " " << formatter(key);
std::cout << ss.str() << "." << std::endl;
}
}
HybridGaussianProductFactor HybridGaussianProductFactor::removeEmpty() const {
auto emptyGaussian = [](const GaussianFactorGraph &graph) {
bool hasNull =
std::any_of(graph.begin(), graph.end(),
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
return hasNull ? GaussianFactorGraph() : graph;
};
return {Base(*this, emptyGaussian)};
}
} // namespace gtsam

View File

@ -0,0 +1,117 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file HybridGaussianProductFactor.h
* @date Oct 2, 2024
* @author Frank Dellaert
* @author Varun Agrawal
*/
#pragma once
#include <gtsam/discrete/DecisionTree.h>
#include <gtsam/inference/Key.h>
#include <gtsam/linear/GaussianFactorGraph.h>
namespace gtsam {
class HybridGaussianFactor;
/// Alias for DecisionTree of GaussianFactorGraphs
class HybridGaussianProductFactor : public DecisionTree<Key, GaussianFactorGraph> {
public:
using Y = GaussianFactorGraph;
using Base = DecisionTree<Key, Y>;
/// @name Constructors
/// @{
/// Default constructor
HybridGaussianProductFactor() = default;
/**
* @brief Construct from a single factor
* @tparam FACTOR Factor type
* @param factor Shared pointer to the factor
*/
template <class FACTOR>
HybridGaussianProductFactor(const std::shared_ptr<FACTOR>& factor) : Base(Y{factor}) {}
/**
* @brief Construct from DecisionTree
* @param tree Decision tree to construct from
*/
HybridGaussianProductFactor(Base&& tree) : Base(std::move(tree)) {}
///@}
/// @name Operators
///@{
/// Add GaussianFactor into HybridGaussianProductFactor
HybridGaussianProductFactor operator+(const GaussianFactor::shared_ptr& factor) const;
/// Add HybridGaussianFactor into HybridGaussianProductFactor
HybridGaussianProductFactor operator+(const HybridGaussianFactor& factor) const;
/// Add-assign operator for GaussianFactor
HybridGaussianProductFactor& operator+=(const GaussianFactor::shared_ptr& factor);
/// Add-assign operator for HybridGaussianFactor
HybridGaussianProductFactor& operator+=(const HybridGaussianFactor& factor);
///@}
/// @name Testable
/// @{
/**
* @brief Print the HybridGaussianProductFactor
* @param s Optional string to prepend
* @param formatter Optional key formatter
*/
void print(const std::string& s = "", const KeyFormatter& formatter = DefaultKeyFormatter) const;
/**
* @brief Check if this HybridGaussianProductFactor is equal to another
* @param other The other HybridGaussianProductFactor to compare with
* @param tol Tolerance for floating point comparisons
* @return true if equal, false otherwise
*/
bool equals(const HybridGaussianProductFactor& other, double tol = 1e-9) const {
return Base::equals(other, [tol](const Y& a, const Y& b) { return a.equals(b, tol); });
}
/// @}
/// @name Other methods
///@{
/**
* @brief Remove empty GaussianFactorGraphs from the decision tree
* @return A new HybridGaussianProductFactor with empty GaussianFactorGraphs removed
*
* If any GaussianFactorGraph in the decision tree contains a nullptr, convert
* that leaf to an empty GaussianFactorGraph. This is needed because the DecisionTree
* will otherwise create a GaussianFactorGraph with a single (null) factor,
* which doesn't register as null.
*/
HybridGaussianProductFactor removeEmpty() const;
///@}
};
// Testable traits
template <>
struct traits<HybridGaussianProductFactor> : public Testable<HybridGaussianProductFactor> {};
} // namespace gtsam

View File

@ -0,0 +1,185 @@
/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file testHybridGaussianProductFactor.cpp
* @brief Unit tests for HybridGaussianProductFactor
* @author Frank Dellaert
* @date October 2024
*/
#include "gtsam/inference/Key.h"
#include <gtsam/base/Testable.h>
#include <gtsam/base/TestableAssertions.h>
#include <gtsam/hybrid/HybridGaussianFactor.h>
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
#include <gtsam/inference/Symbol.h>
#include <gtsam/linear/GaussianConditional.h>
#include <gtsam/linear/JacobianFactor.h>
// Include for test suite
#include <CppUnitLite/TestHarness.h>
#include <memory>
using namespace std;
using namespace gtsam;
using symbol_shorthand::M;
using symbol_shorthand::X;
using symbol_shorthand::Z;
/* ************************************************************************* */
namespace examples {
static const DiscreteKey m1(M(1), 2), m2(M(2), 3);
auto A1 = Matrix::Zero(2, 1);
auto A2 = Matrix::Zero(2, 2);
auto b = Matrix::Zero(2, 1);
auto f10 = std::make_shared<JacobianFactor>(X(1), A1, X(2), A2, b);
auto f11 = std::make_shared<JacobianFactor>(X(1), A1, X(2), A2, b);
auto A3 = Matrix::Zero(2, 3);
auto f20 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
auto f21 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
auto f22 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
HybridGaussianFactor hybridFactorA(m1, {f10, f11});
HybridGaussianFactor hybridFactorB(m2, {f20, f21, f22});
// Simulate a pruned hybrid factor, in this case m2==1 is nulled out.
HybridGaussianFactor prunedFactorB(m2, {f20, nullptr, f22});
} // namespace examples
/* ************************************************************************* */
// Constructor
TEST(HybridGaussianProductFactor, Construct) {
HybridGaussianProductFactor product;
}
/* ************************************************************************* */
// Add two Gaussian factors and check only one leaf in tree
TEST(HybridGaussianProductFactor, AddTwoGaussianFactors) {
using namespace examples;
HybridGaussianProductFactor product;
product += f10;
product += f11;
// Check that the product has only one leaf and no discrete variables.
EXPECT_LONGS_EQUAL(1, product.nrLeaves());
EXPECT(product.labels().empty());
// Retrieve the single leaf
auto leaf = product(Assignment<Key>());
// Check that the leaf contains both factors
EXPECT_LONGS_EQUAL(2, leaf.size());
EXPECT(leaf.at(0) == f10);
EXPECT(leaf.at(1) == f11);
}
/* ************************************************************************* */
// Add two GaussianConditionals and check the resulting tree
TEST(HybridGaussianProductFactor, AddTwoGaussianConditionals) {
// Create two GaussianConditionals
Vector1 d(1.0);
Matrix11 R = I_1x1, S = I_1x1;
auto gc1 = std::make_shared<GaussianConditional>(X(1), d, R, X(2), S);
auto gc2 = std::make_shared<GaussianConditional>(X(2), d, R);
// Create a HybridGaussianProductFactor and add the conditionals
HybridGaussianProductFactor product;
product += std::static_pointer_cast<GaussianFactor>(gc1);
product += std::static_pointer_cast<GaussianFactor>(gc2);
// Check that the product has only one leaf and no discrete variables
EXPECT_LONGS_EQUAL(1, product.nrLeaves());
EXPECT(product.labels().empty());
// Retrieve the single leaf
auto leaf = product(Assignment<Key>());
// Check that the leaf contains both conditionals
EXPECT_LONGS_EQUAL(2, leaf.size());
EXPECT(leaf.at(0) == gc1);
EXPECT(leaf.at(1) == gc2);
}
/* ************************************************************************* */
// Check AsProductFactor
TEST(HybridGaussianProductFactor, AsProductFactor) {
using namespace examples;
auto product = hybridFactorA.asProductFactor();
// Let's check that this worked:
Assignment<Key> mode;
mode[m1.first] = 1;
auto actual = product(mode);
EXPECT(actual.at(0) == f11);
}
/* ************************************************************************* */
// "Add" one hybrid factors together.
TEST(HybridGaussianProductFactor, AddOne) {
using namespace examples;
HybridGaussianProductFactor product;
product += hybridFactorA;
// Let's check that this worked:
Assignment<Key> mode;
mode[m1.first] = 1;
auto actual = product(mode);
EXPECT(actual.at(0) == f11);
}
/* ************************************************************************* */
// "Add" two HFG together.
TEST(HybridGaussianProductFactor, AddTwo) {
using namespace examples;
// Create product of two hybrid factors: it will be a decision tree now on
// both discrete variables m1 and m2:
HybridGaussianProductFactor product;
product += hybridFactorA;
product += hybridFactorB;
// Let's check that this worked:
auto actual00 = product({{M(1), 0}, {M(2), 0}});
EXPECT(actual00.at(0) == f10);
EXPECT(actual00.at(1) == f20);
auto actual12 = product({{M(1), 1}, {M(2), 2}});
EXPECT(actual12.at(0) == f11);
EXPECT(actual12.at(1) == f22);
}
/* ************************************************************************* */
// "Add" two HFG together.
TEST(HybridGaussianProductFactor, AddPruned) {
using namespace examples;
// Create product of two hybrid factors: it will be a decision tree now on
// both discrete variables m1 and m2:
HybridGaussianProductFactor product;
product += hybridFactorA;
product += prunedFactorB;
EXPECT_LONGS_EQUAL(6, product.nrLeaves());
auto pruned = product.removeEmpty();
EXPECT_LONGS_EQUAL(5, pruned.nrLeaves());
}
/* ************************************************************************* */
int main() {
TestResult tr;
return TestRegistry::runAllTests(tr);
}
/* ************************************************************************* */