New product factor class
parent
1bb5b9551b
commit
8b3dfd85e7
|
|
@ -18,17 +18,17 @@
|
|||
* @date Mar 12, 2022
|
||||
*/
|
||||
|
||||
#include <gtsam/base/types.h>
|
||||
#include <gtsam/base/utilities.h>
|
||||
#include <gtsam/discrete/DecisionTree-inl.h>
|
||||
#include <gtsam/discrete/DecisionTree.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
|
||||
#include <gtsam/hybrid/HybridValues.h>
|
||||
#include <gtsam/linear/GaussianFactor.h>
|
||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||
|
||||
#include "gtsam/base/types.h"
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
/* *******************************************************************************/
|
||||
|
|
@ -215,6 +215,12 @@ GaussianFactorGraphTree HybridGaussianFactor::asGaussianFactorGraphTree()
|
|||
return {factors_, wrap};
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
HybridGaussianProductFactor HybridGaussianFactor::asProductFactor() const {
|
||||
return {{factors_,
|
||||
[](const auto &pair) { return GaussianFactorGraph{pair.first}; }}};
|
||||
}
|
||||
|
||||
/* *******************************************************************************/
|
||||
/// Helper method to compute the error of a component.
|
||||
static double PotentiallyPrunedComponentError(
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@
|
|||
#include <gtsam/discrete/DecisionTree.h>
|
||||
#include <gtsam/discrete/DiscreteKey.h>
|
||||
#include <gtsam/hybrid/HybridFactor.h>
|
||||
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
|
||||
#include <gtsam/linear/GaussianFactor.h>
|
||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||
|
||||
|
|
@ -164,6 +165,14 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
|||
sum = factor.add(sum);
|
||||
return sum;
|
||||
}
|
||||
|
||||
/**
|
||||
* @brief Helper function to return factors and functional to create a
|
||||
* DecisionTree of Gaussian Factor Graphs.
|
||||
*
|
||||
* @return HybridGaussianProductFactor
|
||||
*/
|
||||
virtual HybridGaussianProductFactor asProductFactor() const;
|
||||
/// @}
|
||||
|
||||
protected:
|
||||
|
|
@ -175,7 +184,8 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
|
|||
*/
|
||||
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
|
||||
|
||||
private:
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief Helper function to augment the [A|b] matrices in the factor
|
||||
* components with the additional scalar values. This is done by storing the
|
||||
|
|
|
|||
|
|
@ -0,0 +1,89 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file HybridGaussianProductFactor.h
|
||||
* @date Oct 2, 2024
|
||||
* @author Frank Dellaert
|
||||
* @author Varun Agrawal
|
||||
*/
|
||||
|
||||
#include <gtsam/base/types.h>
|
||||
#include <gtsam/discrete/DecisionTree.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
|
||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
static GaussianFactorGraph add(const GaussianFactorGraph &graph1,
|
||||
const GaussianFactorGraph &graph2) {
|
||||
auto result = graph1;
|
||||
result.push_back(graph2);
|
||||
return result;
|
||||
};
|
||||
|
||||
HybridGaussianProductFactor operator+(const HybridGaussianProductFactor &a,
|
||||
const HybridGaussianProductFactor &b) {
|
||||
return a.empty() ? b : HybridGaussianProductFactor(a.apply(b, add));
|
||||
}
|
||||
|
||||
HybridGaussianProductFactor HybridGaussianProductFactor::operator+(
|
||||
const HybridGaussianFactor &factor) const {
|
||||
return *this + factor.asProductFactor();
|
||||
}
|
||||
|
||||
HybridGaussianProductFactor HybridGaussianProductFactor::operator+(
|
||||
const GaussianFactor::shared_ptr &factor) const {
|
||||
return *this + HybridGaussianProductFactor(factor);
|
||||
}
|
||||
|
||||
HybridGaussianProductFactor &HybridGaussianProductFactor::operator+=(
|
||||
const GaussianFactor::shared_ptr &factor) {
|
||||
*this = *this + factor;
|
||||
return *this;
|
||||
}
|
||||
|
||||
HybridGaussianProductFactor &
|
||||
HybridGaussianProductFactor::operator+=(const HybridGaussianFactor &factor) {
|
||||
*this = *this + factor;
|
||||
return *this;
|
||||
}
|
||||
|
||||
void HybridGaussianProductFactor::print(const std::string &s,
|
||||
const KeyFormatter &formatter) const {
|
||||
KeySet keys;
|
||||
auto printer = [&](const Y &graph) {
|
||||
if (keys.size() == 0)
|
||||
keys = graph.keys();
|
||||
return "Graph of size " + std::to_string(graph.size());
|
||||
};
|
||||
Base::print(s, formatter, printer);
|
||||
if (keys.size() > 0) {
|
||||
std::stringstream ss;
|
||||
ss << s << " Keys:";
|
||||
for (auto &&key : keys)
|
||||
ss << " " << formatter(key);
|
||||
std::cout << ss.str() << "." << std::endl;
|
||||
}
|
||||
}
|
||||
|
||||
HybridGaussianProductFactor HybridGaussianProductFactor::removeEmpty() const {
|
||||
auto emptyGaussian = [](const GaussianFactorGraph &graph) {
|
||||
bool hasNull =
|
||||
std::any_of(graph.begin(), graph.end(),
|
||||
[](const GaussianFactor::shared_ptr &ptr) { return !ptr; });
|
||||
return hasNull ? GaussianFactorGraph() : graph;
|
||||
};
|
||||
return {Base(*this, emptyGaussian)};
|
||||
}
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
@ -0,0 +1,117 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file HybridGaussianProductFactor.h
|
||||
* @date Oct 2, 2024
|
||||
* @author Frank Dellaert
|
||||
* @author Varun Agrawal
|
||||
*/
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <gtsam/discrete/DecisionTree.h>
|
||||
#include <gtsam/inference/Key.h>
|
||||
#include <gtsam/linear/GaussianFactorGraph.h>
|
||||
|
||||
namespace gtsam {
|
||||
|
||||
class HybridGaussianFactor;
|
||||
|
||||
/// Alias for DecisionTree of GaussianFactorGraphs
|
||||
class HybridGaussianProductFactor : public DecisionTree<Key, GaussianFactorGraph> {
|
||||
public:
|
||||
using Y = GaussianFactorGraph;
|
||||
using Base = DecisionTree<Key, Y>;
|
||||
|
||||
/// @name Constructors
|
||||
/// @{
|
||||
|
||||
/// Default constructor
|
||||
HybridGaussianProductFactor() = default;
|
||||
|
||||
/**
|
||||
* @brief Construct from a single factor
|
||||
* @tparam FACTOR Factor type
|
||||
* @param factor Shared pointer to the factor
|
||||
*/
|
||||
template <class FACTOR>
|
||||
HybridGaussianProductFactor(const std::shared_ptr<FACTOR>& factor) : Base(Y{factor}) {}
|
||||
|
||||
/**
|
||||
* @brief Construct from DecisionTree
|
||||
* @param tree Decision tree to construct from
|
||||
*/
|
||||
HybridGaussianProductFactor(Base&& tree) : Base(std::move(tree)) {}
|
||||
|
||||
///@}
|
||||
|
||||
/// @name Operators
|
||||
///@{
|
||||
|
||||
/// Add GaussianFactor into HybridGaussianProductFactor
|
||||
HybridGaussianProductFactor operator+(const GaussianFactor::shared_ptr& factor) const;
|
||||
|
||||
/// Add HybridGaussianFactor into HybridGaussianProductFactor
|
||||
HybridGaussianProductFactor operator+(const HybridGaussianFactor& factor) const;
|
||||
|
||||
/// Add-assign operator for GaussianFactor
|
||||
HybridGaussianProductFactor& operator+=(const GaussianFactor::shared_ptr& factor);
|
||||
|
||||
/// Add-assign operator for HybridGaussianFactor
|
||||
HybridGaussianProductFactor& operator+=(const HybridGaussianFactor& factor);
|
||||
|
||||
///@}
|
||||
|
||||
/// @name Testable
|
||||
/// @{
|
||||
|
||||
/**
|
||||
* @brief Print the HybridGaussianProductFactor
|
||||
* @param s Optional string to prepend
|
||||
* @param formatter Optional key formatter
|
||||
*/
|
||||
void print(const std::string& s = "", const KeyFormatter& formatter = DefaultKeyFormatter) const;
|
||||
|
||||
/**
|
||||
* @brief Check if this HybridGaussianProductFactor is equal to another
|
||||
* @param other The other HybridGaussianProductFactor to compare with
|
||||
* @param tol Tolerance for floating point comparisons
|
||||
* @return true if equal, false otherwise
|
||||
*/
|
||||
bool equals(const HybridGaussianProductFactor& other, double tol = 1e-9) const {
|
||||
return Base::equals(other, [tol](const Y& a, const Y& b) { return a.equals(b, tol); });
|
||||
}
|
||||
|
||||
/// @}
|
||||
|
||||
/// @name Other methods
|
||||
///@{
|
||||
|
||||
/**
|
||||
* @brief Remove empty GaussianFactorGraphs from the decision tree
|
||||
* @return A new HybridGaussianProductFactor with empty GaussianFactorGraphs removed
|
||||
*
|
||||
* If any GaussianFactorGraph in the decision tree contains a nullptr, convert
|
||||
* that leaf to an empty GaussianFactorGraph. This is needed because the DecisionTree
|
||||
* will otherwise create a GaussianFactorGraph with a single (null) factor,
|
||||
* which doesn't register as null.
|
||||
*/
|
||||
HybridGaussianProductFactor removeEmpty() const;
|
||||
|
||||
///@}
|
||||
};
|
||||
|
||||
// Testable traits
|
||||
template <>
|
||||
struct traits<HybridGaussianProductFactor> : public Testable<HybridGaussianProductFactor> {};
|
||||
|
||||
} // namespace gtsam
|
||||
|
|
@ -0,0 +1,185 @@
|
|||
/* ----------------------------------------------------------------------------
|
||||
|
||||
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
|
||||
* Atlanta, Georgia 30332-0415
|
||||
* All Rights Reserved
|
||||
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
|
||||
|
||||
* See LICENSE for the license information
|
||||
|
||||
* -------------------------------------------------------------------------- */
|
||||
|
||||
/**
|
||||
* @file testHybridGaussianProductFactor.cpp
|
||||
* @brief Unit tests for HybridGaussianProductFactor
|
||||
* @author Frank Dellaert
|
||||
* @date October 2024
|
||||
*/
|
||||
|
||||
#include "gtsam/inference/Key.h"
|
||||
#include <gtsam/base/Testable.h>
|
||||
#include <gtsam/base/TestableAssertions.h>
|
||||
#include <gtsam/hybrid/HybridGaussianFactor.h>
|
||||
#include <gtsam/hybrid/HybridGaussianProductFactor.h>
|
||||
#include <gtsam/inference/Symbol.h>
|
||||
#include <gtsam/linear/GaussianConditional.h>
|
||||
#include <gtsam/linear/JacobianFactor.h>
|
||||
|
||||
// Include for test suite
|
||||
#include <CppUnitLite/TestHarness.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
using namespace std;
|
||||
using namespace gtsam;
|
||||
using symbol_shorthand::M;
|
||||
using symbol_shorthand::X;
|
||||
using symbol_shorthand::Z;
|
||||
|
||||
/* ************************************************************************* */
|
||||
namespace examples {
|
||||
static const DiscreteKey m1(M(1), 2), m2(M(2), 3);
|
||||
|
||||
auto A1 = Matrix::Zero(2, 1);
|
||||
auto A2 = Matrix::Zero(2, 2);
|
||||
auto b = Matrix::Zero(2, 1);
|
||||
|
||||
auto f10 = std::make_shared<JacobianFactor>(X(1), A1, X(2), A2, b);
|
||||
auto f11 = std::make_shared<JacobianFactor>(X(1), A1, X(2), A2, b);
|
||||
|
||||
auto A3 = Matrix::Zero(2, 3);
|
||||
auto f20 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
|
||||
auto f21 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
|
||||
auto f22 = std::make_shared<JacobianFactor>(X(1), A1, X(3), A3, b);
|
||||
|
||||
HybridGaussianFactor hybridFactorA(m1, {f10, f11});
|
||||
HybridGaussianFactor hybridFactorB(m2, {f20, f21, f22});
|
||||
// Simulate a pruned hybrid factor, in this case m2==1 is nulled out.
|
||||
HybridGaussianFactor prunedFactorB(m2, {f20, nullptr, f22});
|
||||
} // namespace examples
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Constructor
|
||||
TEST(HybridGaussianProductFactor, Construct) {
|
||||
HybridGaussianProductFactor product;
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Add two Gaussian factors and check only one leaf in tree
|
||||
TEST(HybridGaussianProductFactor, AddTwoGaussianFactors) {
|
||||
using namespace examples;
|
||||
|
||||
HybridGaussianProductFactor product;
|
||||
product += f10;
|
||||
product += f11;
|
||||
|
||||
// Check that the product has only one leaf and no discrete variables.
|
||||
EXPECT_LONGS_EQUAL(1, product.nrLeaves());
|
||||
EXPECT(product.labels().empty());
|
||||
|
||||
// Retrieve the single leaf
|
||||
auto leaf = product(Assignment<Key>());
|
||||
|
||||
// Check that the leaf contains both factors
|
||||
EXPECT_LONGS_EQUAL(2, leaf.size());
|
||||
EXPECT(leaf.at(0) == f10);
|
||||
EXPECT(leaf.at(1) == f11);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Add two GaussianConditionals and check the resulting tree
|
||||
TEST(HybridGaussianProductFactor, AddTwoGaussianConditionals) {
|
||||
// Create two GaussianConditionals
|
||||
Vector1 d(1.0);
|
||||
Matrix11 R = I_1x1, S = I_1x1;
|
||||
auto gc1 = std::make_shared<GaussianConditional>(X(1), d, R, X(2), S);
|
||||
auto gc2 = std::make_shared<GaussianConditional>(X(2), d, R);
|
||||
|
||||
// Create a HybridGaussianProductFactor and add the conditionals
|
||||
HybridGaussianProductFactor product;
|
||||
product += std::static_pointer_cast<GaussianFactor>(gc1);
|
||||
product += std::static_pointer_cast<GaussianFactor>(gc2);
|
||||
|
||||
// Check that the product has only one leaf and no discrete variables
|
||||
EXPECT_LONGS_EQUAL(1, product.nrLeaves());
|
||||
EXPECT(product.labels().empty());
|
||||
|
||||
// Retrieve the single leaf
|
||||
auto leaf = product(Assignment<Key>());
|
||||
|
||||
// Check that the leaf contains both conditionals
|
||||
EXPECT_LONGS_EQUAL(2, leaf.size());
|
||||
EXPECT(leaf.at(0) == gc1);
|
||||
EXPECT(leaf.at(1) == gc2);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// Check AsProductFactor
|
||||
TEST(HybridGaussianProductFactor, AsProductFactor) {
|
||||
using namespace examples;
|
||||
auto product = hybridFactorA.asProductFactor();
|
||||
|
||||
// Let's check that this worked:
|
||||
Assignment<Key> mode;
|
||||
mode[m1.first] = 1;
|
||||
auto actual = product(mode);
|
||||
EXPECT(actual.at(0) == f11);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// "Add" one hybrid factors together.
|
||||
TEST(HybridGaussianProductFactor, AddOne) {
|
||||
using namespace examples;
|
||||
HybridGaussianProductFactor product;
|
||||
product += hybridFactorA;
|
||||
|
||||
// Let's check that this worked:
|
||||
Assignment<Key> mode;
|
||||
mode[m1.first] = 1;
|
||||
auto actual = product(mode);
|
||||
EXPECT(actual.at(0) == f11);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// "Add" two HFG together.
|
||||
TEST(HybridGaussianProductFactor, AddTwo) {
|
||||
using namespace examples;
|
||||
|
||||
// Create product of two hybrid factors: it will be a decision tree now on
|
||||
// both discrete variables m1 and m2:
|
||||
HybridGaussianProductFactor product;
|
||||
product += hybridFactorA;
|
||||
product += hybridFactorB;
|
||||
|
||||
// Let's check that this worked:
|
||||
auto actual00 = product({{M(1), 0}, {M(2), 0}});
|
||||
EXPECT(actual00.at(0) == f10);
|
||||
EXPECT(actual00.at(1) == f20);
|
||||
|
||||
auto actual12 = product({{M(1), 1}, {M(2), 2}});
|
||||
EXPECT(actual12.at(0) == f11);
|
||||
EXPECT(actual12.at(1) == f22);
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
// "Add" two HFG together.
|
||||
TEST(HybridGaussianProductFactor, AddPruned) {
|
||||
using namespace examples;
|
||||
|
||||
// Create product of two hybrid factors: it will be a decision tree now on
|
||||
// both discrete variables m1 and m2:
|
||||
HybridGaussianProductFactor product;
|
||||
product += hybridFactorA;
|
||||
product += prunedFactorB;
|
||||
EXPECT_LONGS_EQUAL(6, product.nrLeaves());
|
||||
|
||||
auto pruned = product.removeEmpty();
|
||||
EXPECT_LONGS_EQUAL(5, pruned.nrLeaves());
|
||||
}
|
||||
|
||||
/* ************************************************************************* */
|
||||
int main() {
|
||||
TestResult tr;
|
||||
return TestRegistry::runAllTests(tr);
|
||||
}
|
||||
/* ************************************************************************* */
|
||||
Loading…
Reference in New Issue