Resolved the direction class -> Unit3, and local_coords, local_inv functions to be inline with the state class

release/4.3a0
darshan-17 2025-04-20 21:41:44 -07:00
parent e5f4978539
commit 925e94ecc2
2 changed files with 102 additions and 110 deletions

View File

@ -71,26 +71,7 @@ Matrix numericalDifferential(std::function<Vector(const Vector&)> f, const Vecto
// Core Data Types
//========================================================================
/// Direction class as a S2 element
class Direction {
public:
Unit3 d; // Direction (unit vector on S2)
/**
* Initialize direction
* @param d_vec Direction vector (must be unit norm)
*/
Direction(const Vector3& d_vec);
// Accessor methods for vector components
double x() const;
double y() const;
double z() const;
// Check if the direction contains NaN values
bool hasNaN() const;
};
/// Input class for the Biased Attitude System
@ -118,8 +99,8 @@ struct Input {
/// Measurement class
struct Measurement {
Direction y; /// Measurement direction in sensor frame
Direction d; /// Known direction in global frame
Unit3 y; /// Measurement direction in sensor frame
Unit3 d; /// Known direction in global frame
Matrix3 Sigma; /// Covariance matrix of the measurement
int cal_idx = -1; /// Calibration index (-1 for calibrated sensor)
Measurement(const Vector3& y_vec, const Vector3& d_vec,
@ -138,6 +119,53 @@ public:
const Vector3& b = Vector3::Zero(),
const std::vector<Rot3>& S = std::vector<Rot3>());
/**
* Compute Local coordinates in the state relative to another state.
* @param other The other state
* @return Local coordinates in the tangent space
*/
Vector localCoordinates(const State& other) const {
Vector eps(6 + 3 * S.size());
// First 3 elements - attitude
eps.head<3>() = Rot3::Logmap(R.between(other.R));
// Next 3 elements - bias
// Next 3 elements - bias
eps.segment<3>(3) = other.b - b;
// Remaining elements - calibrations
for (size_t i = 0; i < S.size(); i++) {
eps.segment<3>(6 + 3*i) = Rot3::Logmap(S[i].between(other.S[i]));
}
return eps;
}
/**
* Retract from tangent space back to the manifold
* @param v Vector in the tangent space
* @return New state
*/
State retract(const Vector& v) const {
if (v.size() < 6 || v.size() % 3 != 0 || v.size() != 6 + 3 * static_cast<Eigen::Index>(S.size())) {
throw std::invalid_argument("Vector size does not match state dimensions");
}
// Modify attitude
Rot3 newR = R * Rot3::Expmap(v.head<3>());
// Modify bias
Vector3 newB = b + v.segment<3>(3);
// Modify calibrations
std::vector<Rot3> newS;
for (size_t i = 0; i < S.size(); i++) {
newS.push_back(S[i] * Rot3::Expmap(v.segment<3>(6 + 3*i)));
}
return State(newR, newB, newS);
}
static State identity(int n);
};
@ -254,21 +282,7 @@ Input velocityAction(const G& X, const Input& u);
* @param idx Calibration index
* @return New direction after group action
*/
Vector3 outputAction(const G& X, const Direction& y, int idx);
/**
* Local coordinates assuming xi_0 = identity (Equation 9)
* @param e State representing equivariant error
* @return Local coordinates
*/
Vector local_coords(const State& e);
/**
* Local coordinates inverse assuming xi_0 = identity
* @param eps Local coordinates
* @return Corresponding state
*/
State local_coords_inv(const Vector& eps);
Vector3 outputAction(const G& X, const Unit3& y, int idx);
/**
* Differential of the phi action at E = Id in local coordinates
@ -320,7 +334,7 @@ private:
* @param idx Calibration index
* @return Measurement matrix C0
*/
Matrix measurementMatrixC(const Direction& d, int idx) const;
Matrix measurementMatrixC(const Unit3& d, int idx) const;
/**
* Return the measurement output matrix Dt
@ -460,33 +474,6 @@ Matrix numericalDifferential(std::function<Vector(const Vector&)> f, const Vecto
* direction in 3D space) Uses Unit3's constructor which normalizes vectors
*/
Direction::Direction(const Vector3& d_vec) : d(d_vec) {
if (!checkNorm(d_vec)) {
throw std::invalid_argument("Direction must be a unit vector");
}
}
/** Access the individual components of the direction vector defined above using this methods below
* Uses the Unit3 object from GTSAM to compute the components
*/
double Direction::x() const {
return d.unitVector()[0];
}
double Direction::y() const {
return d.unitVector()[1];
}
double Direction::z() const {
return d.unitVector()[2];
}
bool Direction::hasNaN() const {
Vector3 vec = d.unitVector();
return std::isnan(vec[0]) || std::isnan(vec[1]) || std::isnan(vec[2]);
}
//========================================================================
// Input Class Implementation
//========================================================================
@ -758,14 +745,14 @@ Input velocityAction(const G& X, const Input& u) {
* @return Transformed direction
* Uses Rot3 inverse, matric and Unit3 unitvector functions
*/
Vector3 outputAction(const G& X, const Direction& y, int idx) {
Vector3 outputAction(const G& X, const Unit3& y, int idx) {
if (idx == -1) {
return X.A.inverse().matrix() * y.d.unitVector();
return X.A.inverse().matrix() * y.unitVector();
} else {
if (idx >= static_cast<int>(X.B.size())) {
throw std::out_of_range("Calibration index out of range");
}
return X.B[idx].inverse().matrix() * y.d.unitVector();
return X.B[idx].inverse().matrix() * y.unitVector();
}
}
@ -776,46 +763,46 @@ Vector3 outputAction(const G& X, const Direction& y, int idx) {
* @return Vector with local coordinates
* Uses Rot3 logamo for mapping rotations to the tangent space
*/
Vector local_coords(const State& e) {
if (COORDINATE == "EXPONENTIAL") {
Vector eps(6 + 3 * e.S.size());
// First 3 elements
eps.head<3>() = Rot3::Logmap(e.R);
// Next 3 elements
eps.segment<3>(3) = e.b;
// Remaining elements
for (size_t i = 0; i < e.S.size(); i++) {
eps.segment<3>(6 + 3*i) = Rot3::Logmap(e.S[i]);
}
return eps;
} else if (COORDINATE == "NORMAL") {
throw std::runtime_error("Normal coordinate representation is not implemented yet");
} else {
throw std::invalid_argument("Invalid coordinate representation");
}
}
// Vector local_coords(const State& e) {
// if (COORDINATE == "EXPONENTIAL") {
// Vector eps(6 + 3 * e.S.size());
//
// // First 3 elements
// eps.head<3>() = Rot3::Logmap(e.R);
//
// // Next 3 elements
// eps.segment<3>(3) = e.b;
//
// // Remaining elements
// for (size_t i = 0; i < e.S.size(); i++) {
// eps.segment<3>(6 + 3*i) = Rot3::Logmap(e.S[i]);
// }
//
// return eps;
// } else if (COORDINATE == "NORMAL") {
// throw std::runtime_error("Normal coordinate representation is not implemented yet");
// } else {
// throw std::invalid_argument("Invalid coordinate representation");
// }
// }
/**
* Used to convert the vectorized errors back to state space
* @param eps Local coordinates in the exponential parameterization
* @return State object corresponding to these coordinates
* Uses Rot3 expmap through the G::exp() function
*/
State local_coords_inv(const Vector& eps) {
G X = G::exp(eps);
if (COORDINATE == "EXPONENTIAL") {
std::vector<Rot3> S = X.B;
return State(X.A, eps.segment<3>(3), S);
} else if (COORDINATE == "NORMAL") {
throw std::runtime_error("Normal coordinate representation is not implemented yet");
} else {
throw std::invalid_argument("Invalid coordinate representation");
}
}
// State local_coords_inv(const Vector& eps) {
// G X = G::exp(eps);
//
// if (COORDINATE == "EXPONENTIAL") {
// std::vector<Rot3> S = X.B;
// return State(X.A, eps.segment<3>(3), S);
// } else if (COORDINATE == "NORMAL") {
// throw std::runtime_error("Normal coordinate representation is not implemented yet");
// } else {
// throw std::invalid_argument("Invalid coordinate representation");
// }
// }
/**
* Computes the differential of a state action at the identity of the symmetry
* group
@ -827,7 +814,9 @@ State local_coords_inv(const Vector& eps) {
Matrix stateActionDiff(const State& xi) {
std::function<Vector(const Vector&)> coordsAction =
[&xi](const Vector& U) {
return local_coords(stateAction(G::exp(U), xi));
G groupElement = G::exp(U);
State transformed = stateAction(groupElement, xi);
return xi.localCoordinates(transformed);
};
Vector zeros = Vector::Zero(6 + 3 * xi.S.size());
@ -914,8 +903,8 @@ void EqF::update(const Measurement& y) {
}
// Get vector representations for checking
Vector3 y_vec = y.y.d.unitVector();
Vector3 d_vec = y.d.d.unitVector();
Vector3 y_vec = y.y.unitVector();
Vector3 d_vec = y.d.unitVector();
// Skip update if any NaN values are present
if (std::isnan(y_vec[0]) || std::isnan(y_vec[1]) || std::isnan(y_vec[2]) ||
@ -925,7 +914,7 @@ void EqF::update(const Measurement& y) {
Matrix Ct = measurementMatrixC(y.d, y.cal_idx);
Vector3 action_result = outputAction(X_hat.inv(), y.y, y.cal_idx);
Vector3 delta_vec = Rot3::Hat(y.d.d.unitVector()) * action_result;
Vector3 delta_vec = Rot3::Hat(y.d.unitVector()) * action_result;
Matrix Dt = outputMatrixDt(y.cal_idx);
Matrix S = Ct * Sigma * Ct.transpose() + Dt * y.Sigma * Dt.transpose();
Matrix K = Sigma * Ct.transpose() * S.inverse();
@ -1013,16 +1002,16 @@ Matrix EqF::inputMatrixBt() const {
* @return Measurement matrix
* Uses the matrix zero, Rot3 hat and the Unitvector functions
*/
Matrix EqF::measurementMatrixC(const Direction& d, int idx) const {
Matrix EqF::measurementMatrixC(const Unit3& d, int idx) const {
Matrix Cc = Matrix::Zero(3, 3 * n_cal);
// If the measurement is related to a sensor that has a calibration state
if (idx >= 0) {
// Set the correct 3x3 block in Cc
Cc.block<3, 3>(0, 3 * idx) = Rot3::Hat(d.d.unitVector());
Cc.block<3, 3>(0, 3 * idx) = Rot3::Hat(d.unitVector());
}
Matrix3 wedge_d = Rot3::Hat(d.d.unitVector());
Matrix3 wedge_d = Rot3::Hat(d.unitVector());
// Create the combined matrix
Matrix temp(3, 6 + 3 * n_cal);

View File

@ -262,8 +262,11 @@ void processDataWithEqF(EqF& filter, const std::vector<Data>& data_list, int pri
totalMeasurements++;
// Skip invalid measurements
if (y.y.hasNaN() || y.d.hasNaN()) {
continue;
Vector3 y_vec = y.y.unitVector();
Vector3 d_vec = y.d.unitVector();
if (std::isnan(y_vec[0]) || std::isnan(y_vec[1]) || std::isnan(y_vec[2]) ||
std::isnan(d_vec[0]) || std::isnan(d_vec[1]) || std::isnan(d_vec[2])) {
continue;
}
try {