Three constructor variants extract keys

release/4.3a0
Frank Dellaert 2024-09-26 12:53:50 -07:00
parent e8089dc7cb
commit 35dd42ed23
3 changed files with 134 additions and 40 deletions

View File

@ -140,6 +140,10 @@ class GTSAM_EXPORT HybridFactor : public Factor {
/// @}
protected:
/// protected constructor to initialize the category
HybridFactor(Category category) : category_(category) {}
private:
#ifdef GTSAM_ENABLE_BOOST_SERIALIZATION
/** Serialization function */

View File

@ -26,18 +26,11 @@
#include <gtsam/linear/GaussianFactor.h>
#include <gtsam/linear/GaussianFactorGraph.h>
#include "gtsam/hybrid/HybridFactor.h"
namespace gtsam {
/**
* @brief Helper function to augment the [A|b] matrices in the factor components
* with the additional scalar values.
* This is done by storing the value in
* the `b` vector as an additional row.
*
* @param factors DecisionTree of GaussianFactors and arbitrary scalars.
* Gaussian factor in factors.
* @return HybridGaussianFactor::Factors
*/
/* *******************************************************************************/
HybridGaussianFactor::Factors HybridGaussianFactor::augment(
const FactorValuePairs &factors) {
// Find the minimum value so we can "proselytize" to positive values.
@ -76,13 +69,111 @@ HybridGaussianFactor::Factors HybridGaussianFactor::augment(
return Factors(factors, update);
}
/* *******************************************************************************/
HybridGaussianFactor::HybridGaussianFactor(
const DiscreteKey &discreteKey,
const std::vector<GaussianFactor::shared_ptr> &factors)
: Base(HybridFactor::Category::Hybrid) {
// Extract continuous keys from first-null factor and verify all others
KeyVector continuousKeys;
for (const auto &factor : factors) {
if (!factor) continue;
if (continuousKeys.empty()) {
continuousKeys = factor->keys();
} else if (factor->keys() != continuousKeys) {
throw std::invalid_argument("All factors must have the same keys");
}
}
// Check that this worked.
if (continuousKeys.empty()) {
throw std::invalid_argument("Need at least one non-null factor.");
}
// Initialize the base class
Factor::keys_ = continuousKeys;
Factor::keys_.push_back(discreteKey.first);
Base::discreteKeys_ = {discreteKey};
Base::continuousKeys_ = continuousKeys;
// Build the DecisionTree from factor vector
factors_ = Factors({discreteKey}, factors);
}
/* *******************************************************************************/
HybridGaussianFactor::HybridGaussianFactor(
const DiscreteKey &discreteKey,
const std::vector<GaussianFactorValuePair> &factorPairs)
: Base(HybridFactor::Category::Hybrid) {
// Extract continuous keys from first-null factor and verify all others
KeyVector continuousKeys;
for (const auto &pair : factorPairs) {
if (!pair.first) continue;
if (continuousKeys.empty()) {
continuousKeys = pair.first->keys();
} else if (pair.first->keys() != continuousKeys) {
throw std::invalid_argument("All factors must have the same keys");
}
}
// Check that this worked.
if (continuousKeys.empty()) {
throw std::invalid_argument("Need at least one non-null factor.");
}
// Initialize the base class
Factor::keys_ = continuousKeys;
Factor::keys_.push_back(discreteKey.first);
Base::discreteKeys_ = {discreteKey};
Base::continuousKeys_ = continuousKeys;
// Build the FactorValuePairs DecisionTree
FactorValuePairs pairTree({discreteKey}, factorPairs);
// Assign factors_ after calling augment
factors_ = augment(pairTree);
}
/* *******************************************************************************/
HybridGaussianFactor::HybridGaussianFactor(const DiscreteKeys &discreteKeys,
const FactorValuePairs &factorPairs)
: Base(HybridFactor::Category::Hybrid) {
// Verify that all factors have the same keys
KeyVector continuousKeys;
factorPairs.visit([&](const GaussianFactorValuePair &pair) {
if (pair.first) {
if (continuousKeys.empty()) {
continuousKeys = pair.first->keys();
} else if (pair.first->keys() != continuousKeys) {
throw std::invalid_argument("All factors must have the same keys");
}
}
});
// Check that this worked.
if (continuousKeys.empty()) {
throw std::invalid_argument("Need at least one non-null factor.");
}
// Initialize the base class
Factor::keys_ = continuousKeys;
for (const auto &discreteKey : discreteKeys) {
Factor::keys_.push_back(discreteKey.first);
}
Base::discreteKeys_ = discreteKeys;
Base::continuousKeys_ = continuousKeys;
// Assign factors_ after calling augment
factors_ = augment(factorPairs);
}
/* *******************************************************************************/
bool HybridGaussianFactor::equals(const HybridFactor &lf, double tol) const {
const This *e = dynamic_cast<const This *>(&lf);
if (e == nullptr) return false;
// This will return false if either factors_ is empty or e->factors_ is empty,
// but not if both are empty or both are not empty:
// This will return false if either factors_ is empty or e->factors_ is
// empty, but not if both are empty or both are not empty:
if (factors_.empty() ^ e->factors_.empty()) return false;
// Check the base and the factors:

View File

@ -73,17 +73,6 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
/// Decision tree of Gaussian factors indexed by discrete keys.
Factors factors_;
/// Helper function to "hide" the constants in the Jacobian factors.
static Factors augment(const FactorValuePairs &factors);
/**
* @brief Helper function to return factors and functional to create a
* DecisionTree of Gaussian Factor Graphs.
*
* @return GaussianFactorGraphTree
*/
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
public:
/// @name Constructors
/// @{
@ -96,14 +85,11 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
* providing the factors for each mode m as a vector of factors ϕ_m(x).
* The value ϕ(x,m) for the factor is simply ϕ_m(x).
*
* @param continuousKeys Vector of keys for continuous factors.
* @param discreteKey The discrete key for the "mode", indexing components.
* @param factors Vector of gaussian factors, one for each mode.
*/
HybridGaussianFactor(const KeyVector &continuousKeys,
const DiscreteKey &discreteKey,
const std::vector<GaussianFactor::shared_ptr> &factors)
: Base(continuousKeys, {discreteKey}), factors_({discreteKey}, factors) {}
HybridGaussianFactor(const DiscreteKey &discreteKey,
const std::vector<GaussianFactor::shared_ptr> &factors);
/**
* @brief Construct a new HybridGaussianFactor on a single discrete key,
@ -111,15 +97,11 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
* provided as a vector of pairs (ϕ_m(x), E_m).
* The value ϕ(x,m) for the factor is now ϕ_m(x) + E_m.
*
* @param continuousKeys Vector of keys for continuous factors.
* @param discreteKey The discrete key for the "mode", indexing components.
* @param factors Vector of gaussian factor-scalar pairs, one per mode.
* @param factorPairs Vector of gaussian factor-scalar pairs, one per mode.
*/
HybridGaussianFactor(const KeyVector &continuousKeys,
const DiscreteKey &discreteKey,
const std::vector<GaussianFactorValuePair> &factors)
: HybridGaussianFactor(continuousKeys, {discreteKey},
FactorValuePairs({discreteKey}, factors)) {}
HybridGaussianFactor(const DiscreteKey &discreteKey,
const std::vector<GaussianFactorValuePair> &factorPairs);
/**
* @brief Construct a new HybridGaussianFactor on a several discrete keys M,
@ -127,14 +109,11 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
* scalars are provided as a DecisionTree<Key> of pairs (ϕ_M(x), E_M).
* The value ϕ(x,M) for the factor is again ϕ_m(x) + E_m.
*
* @param continuousKeys A vector of keys representing continuous variables.
* @param discreteKeys Discrete variables and their cardinalities.
* @param factors The decision tree of Gaussian factor/scalar pairs.
*/
HybridGaussianFactor(const KeyVector &continuousKeys,
const DiscreteKeys &discreteKeys,
const FactorValuePairs &factors)
: Base(continuousKeys, discreteKeys), factors_(augment(factors)) {}
HybridGaussianFactor(const DiscreteKeys &discreteKeys,
const FactorValuePairs &factors);
/// @}
/// @name Testable
@ -189,7 +168,27 @@ class GTSAM_EXPORT HybridGaussianFactor : public HybridFactor {
}
/// @}
protected:
/**
* @brief Helper function to return factors and functional to create a
* DecisionTree of Gaussian Factor Graphs.
*
* @return GaussianFactorGraphTree
*/
GaussianFactorGraphTree asGaussianFactorGraphTree() const;
private:
/**
* @brief Helper function to augment the [A|b] matrices in the factor
* components with the additional scalar values. This is done by storing the
* value in the `b` vector as an additional row.
*
* @param factors DecisionTree of GaussianFactors and arbitrary scalars.
* Gaussian factor in factors.
* @return HybridGaussianFactor::Factors
*/
static Factors augment(const FactorValuePairs &factors);
/// Helper method to compute the error of a component.
double potentiallyPrunedComponentError(
const sharedFactor &gf, const VectorValues &continuousValues) const;