Made HNF constructors like HGF

release/4.3a0
Frank Dellaert 2024-09-23 14:44:46 -07:00
parent fd7df61d45
commit 530b0ad742
3 changed files with 98 additions and 55 deletions

View File

@ -21,10 +21,57 @@
namespace gtsam {
/* *******************************************************************************/
HybridNonlinearFactor::HybridNonlinearFactor(const KeyVector& keys,
static void checkKeys(const KeyVector& continuousKeys,
std::vector<NonlinearFactorValuePair>& pairs) {
KeySet factor_keys_set;
for (const auto& pair : pairs) {
auto f = pair.first;
// Insert all factor continuous keys in the continuous keys set.
std::copy(f->keys().begin(), f->keys().end(),
std::inserter(factor_keys_set, factor_keys_set.end()));
}
KeySet continuous_keys_set(continuousKeys.begin(), continuousKeys.end());
if (continuous_keys_set != factor_keys_set) {
throw std::runtime_error(
"HybridNonlinearFactor: The specified continuous keys and the keys in "
"the factors do not match!");
}
}
/* *******************************************************************************/
HybridNonlinearFactor::HybridNonlinearFactor(
const KeyVector& continuousKeys, const DiscreteKey& discreteKey,
const std::vector<NonlinearFactor::shared_ptr>& factors)
: Base(continuousKeys, {discreteKey}) {
std::vector<NonlinearFactorValuePair> pairs;
for (auto&& f : factors) {
pairs.emplace_back(f, 0.0);
}
checkKeys(continuousKeys, pairs);
factors_ = FactorValuePairs({discreteKey}, pairs);
}
/* *******************************************************************************/
HybridNonlinearFactor::HybridNonlinearFactor(
const KeyVector& continuousKeys, const DiscreteKey& discreteKey,
const std::vector<NonlinearFactorValuePair>& factors)
: Base(continuousKeys, {discreteKey}) {
std::vector<NonlinearFactorValuePair> pairs;
KeySet continuous_keys_set(continuousKeys.begin(), continuousKeys.end());
KeySet factor_keys_set;
for (auto&& [f, val] : factors) {
pairs.emplace_back(f, val);
}
checkKeys(continuousKeys, pairs);
factors_ = FactorValuePairs({discreteKey}, pairs);
}
/* *******************************************************************************/
HybridNonlinearFactor::HybridNonlinearFactor(const KeyVector& continuousKeys,
const DiscreteKeys& discreteKeys,
const Factors& factors)
: Base(keys, discreteKeys), factors_(factors) {}
const FactorValuePairs& factors)
: Base(continuousKeys, discreteKeys), factors_(factors) {}
/* *******************************************************************************/
AlgebraicDecisionTree<Key> HybridNonlinearFactor::errorTree(

View File

@ -68,11 +68,11 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
* @brief typedef for DecisionTree which has Keys as node labels and
* pairs of NonlinearFactor & an arbitrary scalar as leaf nodes.
*/
using Factors = DecisionTree<Key, NonlinearFactorValuePair>;
using FactorValuePairs = DecisionTree<Key, NonlinearFactorValuePair>;
private:
/// Decision tree of Gaussian factors indexed by discrete keys.
Factors factors_;
/// Decision tree of nonlinear factors indexed by discrete keys.
FactorValuePairs factors_;
/// HybridFactor method implementation. Should not be used.
AlgebraicDecisionTree<Key> errorTree(
@ -82,62 +82,49 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
}
public:
/// Default constructor, mainly for serialization.
HybridNonlinearFactor() = default;
/**
* @brief Construct from Decision tree.
* @brief Construct a new HybridNonlinearFactor on a single discrete key,
* providing the factors for each mode m as a vector of factors ϕ_m(x).
* The value ϕ(x,m) for the factor is simply ϕ_m(x).
*
* @param keys Vector of keys for continuous factors.
* @param discreteKeys Vector of discrete keys.
* @param factors Decision tree with of shared factors.
* @param continuousKeys Vector of keys for continuous factors.
* @param discreteKey The discrete key for the "mode", indexing components.
* @param factors Vector of gaussian factors, one for each mode.
*/
HybridNonlinearFactor(const KeyVector& keys, const DiscreteKeys& discreteKeys,
const Factors& factors);
HybridNonlinearFactor(
const KeyVector& continuousKeys, const DiscreteKey& discreteKey,
const std::vector<NonlinearFactor::shared_ptr>& factors);
/**
* @brief Convenience constructor that generates the underlying factor
* decision tree for us.
* @brief Construct a new HybridNonlinearFactor on a single discrete key,
* including a scalar error value for each mode m. The factors and scalars are
* provided as a vector of pairs (ϕ_m(x), E_m).
* The value ϕ(x,m) for the factor is now ϕ_m(x) + E_m.
*
* Here it is important that the vector of factors has the correct number of
* elements based on the number of discrete keys and the cardinality of the
* keys, so that the decision tree is constructed appropriately.
*
* @tparam FACTOR The type of the factor shared pointers being passed in.
* Will be typecast to NonlinearFactor shared pointers.
* @param keys Vector of keys for continuous factors.
* @param discreteKey The discrete key indexing each component factor.
* @param factors Vector of nonlinear factor and scalar pairs.
* Same size as the cardinality of discreteKey.
* @param continuousKeys Vector of keys for continuous factors.
* @param discreteKey The discrete key for the "mode", indexing components.
* @param factors Vector of gaussian factor-scalar pairs, one per mode.
*/
template <typename FACTOR>
HybridNonlinearFactor(
const KeyVector& keys, const DiscreteKey& discreteKey,
const std::vector<std::pair<std::shared_ptr<FACTOR>, double>>& factors)
: Base(keys, {discreteKey}) {
std::vector<NonlinearFactorValuePair> nonlinear_factors;
KeySet continuous_keys_set(keys.begin(), keys.end());
KeySet factor_keys_set;
for (auto&& [f, val] : factors) {
// Insert all factor continuous keys in the continuous keys set.
std::copy(f->keys().begin(), f->keys().end(),
std::inserter(factor_keys_set, factor_keys_set.end()));
if (auto nf = std::dynamic_pointer_cast<NonlinearFactor>(f)) {
nonlinear_factors.emplace_back(nf, val);
} else {
throw std::runtime_error(
"Factors passed into HybridNonlinearFactor need to be nonlinear!");
}
}
factors_ = Factors({discreteKey}, nonlinear_factors);
if (continuous_keys_set != factor_keys_set) {
throw std::runtime_error(
"The specified continuous keys and the keys in the factors don't "
"match!");
}
}
HybridNonlinearFactor(const KeyVector& continuousKeys,
const DiscreteKey& discreteKey,
const std::vector<NonlinearFactorValuePair>& factors);
/**
* @brief Construct a new HybridNonlinearFactor on a several discrete keys M,
* including a scalar error value for each assignment m. The factors and
* scalars are provided as a DecisionTree<Key> of pairs (ϕ_M(x), E_M).
* The value ϕ(x,M) for the factor is again ϕ_m(x) + E_m.
*
* @param continuousKeys A vector of keys representing continuous variables.
* @param discreteKeys Discrete variables and their cardinalities.
* @param factors The decision tree of nonlinear factor/scalar pairs.
*/
HybridNonlinearFactor(const KeyVector& continuousKeys,
const DiscreteKeys& discreteKeys,
const FactorValuePairs& factors);
/**
* @brief Compute error of the HybridNonlinearFactor as a tree.
*
@ -196,4 +183,9 @@ class GTSAM_EXPORT HybridNonlinearFactor : public HybridFactor {
const Values& continuousValues) const;
};
// traits
template <>
struct traits<HybridNonlinearFactor> : public Testable<HybridNonlinearFactor> {
};
} // namespace gtsam

View File

@ -246,14 +246,18 @@ class HybridNonlinearFactorGraph {
#include <gtsam/hybrid/HybridNonlinearFactor.h>
class HybridNonlinearFactor : gtsam::HybridFactor {
HybridNonlinearFactor(
const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys,
const gtsam::DecisionTree<
gtsam::Key, std::pair<gtsam::NonlinearFactor*, double>>& factors);
const gtsam::KeyVector& keys, const gtsam::DiscreteKey& discreteKey,
const std::vector<gtsam::NonlinearFactor*>& factors);
HybridNonlinearFactor(
const gtsam::KeyVector& keys, const gtsam::DiscreteKey& discreteKey,
const std::vector<std::pair<gtsam::NonlinearFactor*, double>>& factors);
HybridNonlinearFactor(
const gtsam::KeyVector& keys, const gtsam::DiscreteKeys& discreteKeys,
const gtsam::DecisionTree<
gtsam::Key, std::pair<gtsam::NonlinearFactor*, double>>& factors);
double error(const gtsam::Values& continuousValues,
const gtsam::DiscreteValues& discreteValues) const;