From 925e94ecc22bf5205b03f6c52bbb6697c79f3610 Mon Sep 17 00:00:00 2001 From: darshan-17 Date: Sun, 20 Apr 2025 21:41:44 -0700 Subject: [PATCH] Resolved the direction class -> Unit3, and local_coords, local_inv functions to be inline with the state class --- examples/ABC_EQF.h | 205 ++++++++++++++++++-------------------- examples/ABC_EQF_Demo.cpp | 7 +- 2 files changed, 102 insertions(+), 110 deletions(-) diff --git a/examples/ABC_EQF.h b/examples/ABC_EQF.h index 41cacbb19..341149cd3 100644 --- a/examples/ABC_EQF.h +++ b/examples/ABC_EQF.h @@ -71,26 +71,7 @@ Matrix numericalDifferential(std::function f, const Vecto // Core Data Types //======================================================================== -/// Direction class as a S2 element -class Direction { -public: - Unit3 d; // Direction (unit vector on S2) - - /** - * Initialize direction - * @param d_vec Direction vector (must be unit norm) - */ - Direction(const Vector3& d_vec); - - // Accessor methods for vector components - double x() const; - double y() const; - double z() const; - - // Check if the direction contains NaN values - bool hasNaN() const; -}; /// Input class for the Biased Attitude System @@ -118,8 +99,8 @@ struct Input { /// Measurement class struct Measurement { - Direction y; /// Measurement direction in sensor frame - Direction d; /// Known direction in global frame + Unit3 y; /// Measurement direction in sensor frame + Unit3 d; /// Known direction in global frame Matrix3 Sigma; /// Covariance matrix of the measurement int cal_idx = -1; /// Calibration index (-1 for calibrated sensor) Measurement(const Vector3& y_vec, const Vector3& d_vec, @@ -138,6 +119,53 @@ public: const Vector3& b = Vector3::Zero(), const std::vector& S = std::vector()); + /** + * Compute Local coordinates in the state relative to another state. + * @param other The other state + * @return Local coordinates in the tangent space + */ + Vector localCoordinates(const State& other) const { + Vector eps(6 + 3 * S.size()); + + // First 3 elements - attitude + eps.head<3>() = Rot3::Logmap(R.between(other.R)); + // Next 3 elements - bias + // Next 3 elements - bias + eps.segment<3>(3) = other.b - b; + + // Remaining elements - calibrations + for (size_t i = 0; i < S.size(); i++) { + eps.segment<3>(6 + 3*i) = Rot3::Logmap(S[i].between(other.S[i])); + } + + return eps; + } + + /** + * Retract from tangent space back to the manifold + * @param v Vector in the tangent space + * @return New state + */ + State retract(const Vector& v) const { + if (v.size() < 6 || v.size() % 3 != 0 || v.size() != 6 + 3 * static_cast(S.size())) { + throw std::invalid_argument("Vector size does not match state dimensions"); + } + + // Modify attitude + Rot3 newR = R * Rot3::Expmap(v.head<3>()); + + // Modify bias + Vector3 newB = b + v.segment<3>(3); + + // Modify calibrations + std::vector newS; + for (size_t i = 0; i < S.size(); i++) { + newS.push_back(S[i] * Rot3::Expmap(v.segment<3>(6 + 3*i))); + } + + return State(newR, newB, newS); + } + static State identity(int n); }; @@ -254,21 +282,7 @@ Input velocityAction(const G& X, const Input& u); * @param idx Calibration index * @return New direction after group action */ -Vector3 outputAction(const G& X, const Direction& y, int idx); - -/** - * Local coordinates assuming xi_0 = identity (Equation 9) - * @param e State representing equivariant error - * @return Local coordinates - */ -Vector local_coords(const State& e); - -/** - * Local coordinates inverse assuming xi_0 = identity - * @param eps Local coordinates - * @return Corresponding state - */ -State local_coords_inv(const Vector& eps); +Vector3 outputAction(const G& X, const Unit3& y, int idx); /** * Differential of the phi action at E = Id in local coordinates @@ -320,7 +334,7 @@ private: * @param idx Calibration index * @return Measurement matrix C0 */ - Matrix measurementMatrixC(const Direction& d, int idx) const; + Matrix measurementMatrixC(const Unit3& d, int idx) const; /** * Return the measurement output matrix Dt @@ -460,33 +474,6 @@ Matrix numericalDifferential(std::function f, const Vecto * direction in 3D space) Uses Unit3's constructor which normalizes vectors */ - -Direction::Direction(const Vector3& d_vec) : d(d_vec) { - if (!checkNorm(d_vec)) { - throw std::invalid_argument("Direction must be a unit vector"); - } -} - /** Access the individual components of the direction vector defined above using this methods below - * Uses the Unit3 object from GTSAM to compute the components - */ - -double Direction::x() const { - return d.unitVector()[0]; -} - -double Direction::y() const { - return d.unitVector()[1]; -} - -double Direction::z() const { - return d.unitVector()[2]; -} - -bool Direction::hasNaN() const { - Vector3 vec = d.unitVector(); - return std::isnan(vec[0]) || std::isnan(vec[1]) || std::isnan(vec[2]); -} - //======================================================================== // Input Class Implementation //======================================================================== @@ -758,14 +745,14 @@ Input velocityAction(const G& X, const Input& u) { * @return Transformed direction * Uses Rot3 inverse, matric and Unit3 unitvector functions */ -Vector3 outputAction(const G& X, const Direction& y, int idx) { +Vector3 outputAction(const G& X, const Unit3& y, int idx) { if (idx == -1) { - return X.A.inverse().matrix() * y.d.unitVector(); + return X.A.inverse().matrix() * y.unitVector(); } else { if (idx >= static_cast(X.B.size())) { throw std::out_of_range("Calibration index out of range"); } - return X.B[idx].inverse().matrix() * y.d.unitVector(); + return X.B[idx].inverse().matrix() * y.unitVector(); } } @@ -776,46 +763,46 @@ Vector3 outputAction(const G& X, const Direction& y, int idx) { * @return Vector with local coordinates * Uses Rot3 logamo for mapping rotations to the tangent space */ -Vector local_coords(const State& e) { - if (COORDINATE == "EXPONENTIAL") { - Vector eps(6 + 3 * e.S.size()); - - // First 3 elements - eps.head<3>() = Rot3::Logmap(e.R); - - // Next 3 elements - eps.segment<3>(3) = e.b; - - // Remaining elements - for (size_t i = 0; i < e.S.size(); i++) { - eps.segment<3>(6 + 3*i) = Rot3::Logmap(e.S[i]); - } - - return eps; - } else if (COORDINATE == "NORMAL") { - throw std::runtime_error("Normal coordinate representation is not implemented yet"); - } else { - throw std::invalid_argument("Invalid coordinate representation"); - } -} +// Vector local_coords(const State& e) { +// if (COORDINATE == "EXPONENTIAL") { +// Vector eps(6 + 3 * e.S.size()); +// +// // First 3 elements +// eps.head<3>() = Rot3::Logmap(e.R); +// +// // Next 3 elements +// eps.segment<3>(3) = e.b; +// +// // Remaining elements +// for (size_t i = 0; i < e.S.size(); i++) { +// eps.segment<3>(6 + 3*i) = Rot3::Logmap(e.S[i]); +// } +// +// return eps; +// } else if (COORDINATE == "NORMAL") { +// throw std::runtime_error("Normal coordinate representation is not implemented yet"); +// } else { +// throw std::invalid_argument("Invalid coordinate representation"); +// } +// } /** * Used to convert the vectorized errors back to state space * @param eps Local coordinates in the exponential parameterization * @return State object corresponding to these coordinates * Uses Rot3 expmap through the G::exp() function */ -State local_coords_inv(const Vector& eps) { - G X = G::exp(eps); - - if (COORDINATE == "EXPONENTIAL") { - std::vector S = X.B; - return State(X.A, eps.segment<3>(3), S); - } else if (COORDINATE == "NORMAL") { - throw std::runtime_error("Normal coordinate representation is not implemented yet"); - } else { - throw std::invalid_argument("Invalid coordinate representation"); - } -} +// State local_coords_inv(const Vector& eps) { +// G X = G::exp(eps); +// +// if (COORDINATE == "EXPONENTIAL") { +// std::vector S = X.B; +// return State(X.A, eps.segment<3>(3), S); +// } else if (COORDINATE == "NORMAL") { +// throw std::runtime_error("Normal coordinate representation is not implemented yet"); +// } else { +// throw std::invalid_argument("Invalid coordinate representation"); +// } +// } /** * Computes the differential of a state action at the identity of the symmetry * group @@ -827,7 +814,9 @@ State local_coords_inv(const Vector& eps) { Matrix stateActionDiff(const State& xi) { std::function coordsAction = [&xi](const Vector& U) { - return local_coords(stateAction(G::exp(U), xi)); + G groupElement = G::exp(U); + State transformed = stateAction(groupElement, xi); + return xi.localCoordinates(transformed); }; Vector zeros = Vector::Zero(6 + 3 * xi.S.size()); @@ -914,8 +903,8 @@ void EqF::update(const Measurement& y) { } // Get vector representations for checking - Vector3 y_vec = y.y.d.unitVector(); - Vector3 d_vec = y.d.d.unitVector(); + Vector3 y_vec = y.y.unitVector(); + Vector3 d_vec = y.d.unitVector(); // Skip update if any NaN values are present if (std::isnan(y_vec[0]) || std::isnan(y_vec[1]) || std::isnan(y_vec[2]) || @@ -925,7 +914,7 @@ void EqF::update(const Measurement& y) { Matrix Ct = measurementMatrixC(y.d, y.cal_idx); Vector3 action_result = outputAction(X_hat.inv(), y.y, y.cal_idx); - Vector3 delta_vec = Rot3::Hat(y.d.d.unitVector()) * action_result; + Vector3 delta_vec = Rot3::Hat(y.d.unitVector()) * action_result; Matrix Dt = outputMatrixDt(y.cal_idx); Matrix S = Ct * Sigma * Ct.transpose() + Dt * y.Sigma * Dt.transpose(); Matrix K = Sigma * Ct.transpose() * S.inverse(); @@ -1013,16 +1002,16 @@ Matrix EqF::inputMatrixBt() const { * @return Measurement matrix * Uses the matrix zero, Rot3 hat and the Unitvector functions */ -Matrix EqF::measurementMatrixC(const Direction& d, int idx) const { +Matrix EqF::measurementMatrixC(const Unit3& d, int idx) const { Matrix Cc = Matrix::Zero(3, 3 * n_cal); // If the measurement is related to a sensor that has a calibration state if (idx >= 0) { // Set the correct 3x3 block in Cc - Cc.block<3, 3>(0, 3 * idx) = Rot3::Hat(d.d.unitVector()); + Cc.block<3, 3>(0, 3 * idx) = Rot3::Hat(d.unitVector()); } - Matrix3 wedge_d = Rot3::Hat(d.d.unitVector()); + Matrix3 wedge_d = Rot3::Hat(d.unitVector()); // Create the combined matrix Matrix temp(3, 6 + 3 * n_cal); diff --git a/examples/ABC_EQF_Demo.cpp b/examples/ABC_EQF_Demo.cpp index a2267c3ac..e90283873 100644 --- a/examples/ABC_EQF_Demo.cpp +++ b/examples/ABC_EQF_Demo.cpp @@ -262,8 +262,11 @@ void processDataWithEqF(EqF& filter, const std::vector& data_list, int pri totalMeasurements++; // Skip invalid measurements - if (y.y.hasNaN() || y.d.hasNaN()) { - continue; + Vector3 y_vec = y.y.unitVector(); + Vector3 d_vec = y.d.unitVector(); + if (std::isnan(y_vec[0]) || std::isnan(y_vec[1]) || std::isnan(y_vec[2]) || + std::isnan(d_vec[0]) || std::isnan(d_vec[1]) || std::isnan(d_vec[2])) { + continue; } try {