/**
 * @file ABC_EQF_Demo.cpp
 * @brief Demonstration of the full Attitude-Bias-Calibration Equivariant Filter
 *
 * This demo shows the Equivariant Filter (EqF) for attitude estimation
 * with both gyroscope bias and sensor extrinsic calibration, based on the paper:
 * "Overcoming Bias: Equivariant Filter Design for Biased Attitude Estimation
 * with Online Calibration" by Fornasier et al.
 * Authors: Darshan Rajasekaran & Jennifer Oum
 */

#include "ABC_EQF.h"

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <fstream>
#include <iostream>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

// Use namespace for convenience
using namespace abc_eqf_lib;
using namespace gtsam;

//========================================================================
// Data Processing Functions
//========================================================================

/**
 * Load data from CSV file into a vector of Data objects for the EqF
 *
 * CSV format (one header line, then one row per time step, 39 columns):
 * - t: Time
 * - q_w, q_x, q_y, q_z: True attitude quaternion
 * - b_x, b_y, b_z: True bias
 * - cq_w_0, cq_x_0, cq_y_0, cq_z_0: True calibration quaternion
 * - w_x, w_y, w_z: Angular velocity measurements
 * - std_w_x, std_w_y, std_w_z: Angular velocity measurement standard deviations
 * - std_b_x, std_b_y, std_b_z: Bias process noise standard deviations
 * - y_x_0, y_y_0, y_z_0, y_x_1, y_y_1, y_z_1: Direction measurements
 * - std_y_x_0, std_y_y_0, std_y_z_0, std_y_x_1, std_y_y_1, std_y_z_1: Direction measurement standard deviations
 * - d_x_0, d_y_0, d_z_0, d_x_1, d_y_1, d_z_1: Reference directions
 *
 * @param filename Path to the CSV file
 * @param startRow 0-based index of the first data row (header excluded) to load (default: 0)
 * @param maxRows Maximum number of rows to load; any negative value means all (default: -1)
 * @param downsample Keep every `downsample`-th row starting at startRow (default: 1, no downsampling)
 * @return Vector of Data objects
 * @throws std::runtime_error if the file cannot be opened
 * @throws std::invalid_argument if downsample < 1
 */
std::vector<Data> loadDataFromCSV(const std::string& filename, int startRow = 0,
                                  int maxRows = -1, int downsample = 1);

/**
 * Process data with EqF and print summary results
 * @param filter Initialized EqF filter
 * @param data_list Vector of Data objects to process
 * @param printInterval Currently ignored (kept for API compatibility); progress is
 *        always reported in ten steps
 */
void processDataWithEqF(EqF& filter, const std::vector<Data>& data_list,
                        int printInterval = 10);

//========================================================================
// Internal helpers
//========================================================================

namespace {

/// Number of numeric columns a valid CSV row must contain (see loadDataFromCSV).
constexpr std::size_t kExpectedColumns = 39;

/// Conversion factor from radians to degrees (M_PI is not standard C++).
constexpr double RAD_TO_DEG = 180.0 / 3.14159265358979323846;

/// Error of an estimated state with respect to the ground-truth state.
struct EstimationErrors {
  Vector3 attitude = Vector3::Zero();     // Logmap(R_true.between(R_est)), radians
  Vector3 bias = Vector3::Zero();         // b_est - b_true
  Vector3 calibration = Vector3::Zero();  // Logmap(S_true[0].between(S_est[0])), zero if absent
};

/// Computes attitude, bias and (first) calibration errors of `estimate` against `truth`.
EstimationErrors computeErrors(const State& truth, const State& estimate) {
  EstimationErrors errors;
  errors.attitude = Rot3::Logmap(truth.R.between(estimate.R));
  errors.bias = estimate.b - truth.b;
  if (!truth.S.empty() && !estimate.S.empty()) {
    errors.calibration = Rot3::Logmap(truth.S[0].between(estimate.S[0]));
  }
  return errors;
}

/// Builds one direction measurement from a parsed CSV row.
/// yCol/stdCol/dCol are the first of three consecutive columns holding the measured
/// direction, its per-axis standard deviations and the reference direction.
Measurement makeDirectionMeasurement(const std::vector<double>& values,
                                     std::size_t yCol, std::size_t stdCol,
                                     std::size_t dCol, int calIdx) {
  Vector3 y(values[yCol], values[yCol + 1], values[yCol + 2]);
  Vector3 d(values[dCol], values[dCol + 1], values[dCol + 2]);

  // Normalize vectors only when they are noticeably off the unit sphere.
  // std::abs (not ::abs) guarantees the double overload is selected.
  if (std::abs(y.norm() - 1.0) > 1e-5) y.normalize();
  if (std::abs(d.norm() - 1.0) > 1e-5) d.normalize();

  // Diagonal measurement covariance built from the per-axis standard deviations.
  Matrix3 cov = Matrix3::Zero();
  for (int k = 0; k < 3; ++k) {
    const double sigma = values[stdCol + k];
    cov(k, k) = sigma * sigma;
  }
  return Measurement(y, d, cov, calIdx);
}

/// True when a line contains nothing but whitespace (e.g. a trailing empty or "\r" line).
bool isBlankLine(const std::string& line) {
  return line.find_first_not_of(" \t\r\n") == std::string::npos;
}

}  // namespace

//========================================================================
// Data Processing Functions Implementation
//========================================================================

/**
 * @brief Loads the test data from the csv file
 * @param filename path to the csv file
 * @param startRow 0-based index of the first data row to load (header excluded); negative is treated as 0
 * @param maxRows maximum rows to load; any negative value loads all rows
 * @param downsample Downsample factor (>= 1); rows startRow, startRow + downsample, ... are kept
 * @return A list of data objects
 *
 * Unparsable fields are counted as errors and replaced by 0.0; rows with fewer than
 * 39 columns are counted as errors and skipped. Blank lines are ignored.
 */
std::vector<Data> loadDataFromCSV(const std::string& filename, int startRow,
                                  int maxRows, int downsample) {
  if (downsample < 1) {
    // A zero/negative factor would otherwise be a modulo by zero below.
    throw std::invalid_argument("downsample factor must be >= 1, got " +
                                std::to_string(downsample));
  }
  startRow = std::max(startRow, 0);
  const bool unlimited = maxRows < 0;

  std::ifstream file(filename);
  if (!file.is_open()) {
    throw std::runtime_error("Failed to open file: " + filename);
  }

  std::cout << "Loading data from " << filename << "..." << std::flush;

  std::vector<Data> data_list;
  std::string line;
  int dataRowIndex = -1;  // 0-based index of the current data row (header excluded)
  int rowCount = 0;       // rows actually loaded
  int errorCount = 0;
  double prevTime = 0.0;

  // Skip header
  std::getline(file, line);

  // Check the row limit before reading so no line is consumed needlessly.
  while ((unlimited || rowCount < maxRows) && std::getline(file, line)) {
    if (isBlankLine(line)) {
      continue;
    }
    ++dataRowIndex;

    // Rows before startRow are skipped; afterwards keep every downsample-th row,
    // starting with startRow itself.
    if (dataRowIndex < startRow || (dataRowIndex - startRow) % downsample != 0) {
      continue;
    }

    // Parse line into values
    std::vector<double> values;
    values.reserve(kExpectedColumns);
    std::istringstream ss(line);
    std::string token;
    while (std::getline(ss, token, ',')) {
      try {
        values.push_back(std::stod(token));
      } catch (const std::exception&) {
        // Keep column alignment: count the problem and use a default value.
        errorCount++;
        values.push_back(0.0);
      }
    }

    // Check if we have enough values
    if (values.size() < kExpectedColumns) {
      errorCount++;
      continue;
    }

    // Extract time; dt spans any rows removed by downsampling.
    const double t = values[0];
    const double dt = (rowCount == 0) ? 0.0 : t - prevTime;
    prevTime = t;

    // Create ground truth state
    Quaternion quat(values[1], values[2], values[3], values[4]);  // w, x, y, z
    Rot3 R = Rot3(quat);
    Vector3 b(values[5], values[6], values[7]);
    Quaternion calQuat(values[8], values[9], values[10], values[11]);  // w, x, y, z
    std::vector<Rot3> S = {Rot3(calQuat)};
    State xi(R, b, S);

    // Create input
    Vector3 w(values[12], values[13], values[14]);

    // Create input covariance matrix (6x6)
    // First 3x3 block for angular velocity (columns 15-17), second 3x3 block for
    // bias process noise (columns 18-20); both diagonal, from standard deviations.
    Matrix inputCov = Matrix::Zero(6, 6);
    for (int k = 0; k < 6; ++k) {
      const double sigma = values[15 + k];
      inputCov(k, k) = sigma * sigma;
    }
    Input u{w, inputCov};

    // Create measurements
    std::vector<Measurement> measurements;
    measurements.reserve(2);
    // First measurement: sensor whose extrinsic calibration is estimated as
    // calibration state S[0] (cal_idx = 0).
    measurements.push_back(makeDirectionMeasurement(values, 21, 27, 33, 0));
    // Second measurement: calibrated sensor, no calibration state (cal_idx = -1).
    measurements.push_back(makeDirectionMeasurement(values, 24, 30, 36, -1));

    // Create Data object (1 calibration state, 2 measurements) and add to list
    data_list.push_back(Data(xi, 1, u, measurements, 2, t, dt));
    rowCount++;

    // Show loading progress every 1000 rows
    if (rowCount % 1000 == 0) {
      std::cout << "." << std::flush;
    }
  }

  std::cout << " Done!" << std::endl;
  std::cout << "Loaded " << data_list.size() << " data points";
  if (errorCount > 0) {
    std::cout << " (" << errorCount << " errors encountered)";
  }
  std::cout << std::endl;

  return data_list;
}

/**
 * @brief Takes in the data, runs an EqF on it and reports the results
 * @param filter Initialized EqF filter
 * @param data_list Data to process, in time order
 * @param printInterval Currently ignored; progress is shown in ten steps
 *
 * For each data point the filter is propagated with the input, then updated with
 * every measurement whose direction vectors contain no NaN. Prints performance
 * statistics: processing time, average and final errors, and the final estimate
 * next to the ground truth. Uses Rot3 between, Logmap and rpy functions.
 */
void processDataWithEqF(EqF& filter, const std::vector<Data>& data_list,
                        int printInterval) {
  (void)printInterval;  // Kept for API compatibility only.

  if (data_list.empty()) {
    std::cerr << "No data to process" << std::endl;
    return;
  }

  std::cout << "Processing " << data_list.size() << " data points with EqF..." << std::endl;

  // Track performance metrics (error norms per data point)
  std::vector<double> att_errors;
  std::vector<double> bias_errors;
  std::vector<double> cal_errors;
  att_errors.reserve(data_list.size());
  bias_errors.reserve(data_list.size());
  cal_errors.reserve(data_list.size());

  int totalMeasurements = 0;
  int validMeasurements = 0;

  // Print a progress indicator: 10 progress updates
  const std::size_t progressStep = std::max<std::size_t>(data_list.size() / 10, 1);

  // steady_clock is monotonic, which is what interval timing needs.
  const auto start = std::chrono::steady_clock::now();

  std::cout << "Progress: ";

  for (std::size_t i = 0; i < data_list.size(); i++) {
    const Data& data = data_list[i];

    // Propagate filter with current input and time step
    filter.propagation(data.u, data.dt);

    // Process all measurements
    for (const auto& y : data.y) {
      totalMeasurements++;

      // Skip invalid measurements
      const Vector3 y_vec = y.y.unitVector();
      const Vector3 d_vec = y.d.unitVector();
      if (y_vec.hasNaN() || d_vec.hasNaN()) {
        continue;
      }

      try {
        filter.update(y);
        validMeasurements++;
      } catch (const std::exception& e) {
        std::cerr << "Error updating at t=" << data.t << ": " << e.what() << std::endl;
      }
    }

    // Compare current state estimate with ground truth
    const EstimationErrors errors = computeErrors(data.xi, filter.stateEstimate());
    att_errors.push_back(errors.attitude.norm());
    bias_errors.push_back(errors.bias.norm());
    cal_errors.push_back(errors.calibration.norm());

    // Show progress dots
    if (i % progressStep == 0) {
      std::cout << "." << std::flush;
    }
  }

  std::cout << " Done!" << std::endl;

  const auto end = std::chrono::steady_clock::now();
  const std::chrono::duration<double> elapsed = end - start;

  // Calculate average errors (one entry per data point, so never empty here)
  const double n = static_cast<double>(att_errors.size());
  const double avg_att_error = std::accumulate(att_errors.begin(), att_errors.end(), 0.0) / n;
  const double avg_bias_error = std::accumulate(bias_errors.begin(), bias_errors.end(), 0.0) / n;
  const double avg_cal_error = std::accumulate(cal_errors.begin(), cal_errors.end(), 0.0) / n;

  // Calculate final errors from last data point
  const Data& final_data = data_list.back();
  const State final_estimate = filter.stateEstimate();
  const EstimationErrors final_errors = computeErrors(final_data.xi, final_estimate);

  // Print summary statistics
  std::cout << "\n=== Filter Performance Summary ===" << std::endl;
  std::cout << "Processing time: " << elapsed.count() << " seconds" << std::endl;
  std::cout << "Processed measurements: " << totalMeasurements << " (valid: " << validMeasurements << ")" << std::endl;

  // Average errors
  std::cout << "\n-- Average Errors --" << std::endl;
  std::cout << "Attitude: " << (avg_att_error * RAD_TO_DEG) << "°" << std::endl;
  std::cout << "Bias: " << avg_bias_error << std::endl;
  std::cout << "Calibration: " << (avg_cal_error * RAD_TO_DEG) << "°" << std::endl;

  // Final errors
  std::cout << "\n-- Final Errors --" << std::endl;
  std::cout << "Attitude: " << (final_errors.attitude.norm() * RAD_TO_DEG) << "°" << std::endl;
  std::cout << "Bias: " << final_errors.bias.norm() << std::endl;
  std::cout << "Calibration: " << (final_errors.calibration.norm() * RAD_TO_DEG) << "°" << std::endl;

  // Print a brief comparison of final estimate vs ground truth
  std::cout << "\n-- Final State vs Ground Truth --" << std::endl;
  std::cout << "Attitude (RPY) - Estimate: "
            << (final_estimate.R.rpy() * RAD_TO_DEG).transpose() << "° | Truth: "
            << (final_data.xi.R.rpy() * RAD_TO_DEG).transpose() << "°" << std::endl;
  std::cout << "Bias - Estimate: " << final_estimate.b.transpose()
            << " | Truth: " << final_data.xi.b.transpose() << std::endl;

  if (!final_estimate.S.empty() && !final_data.xi.S.empty()) {
    std::cout << "Calibration (RPY) - Estimate: "
              << (final_estimate.S[0].rpy() * RAD_TO_DEG).transpose() << "° | Truth: "
              << (final_data.xi.S[0].rpy() * RAD_TO_DEG).transpose() << "°" << std::endl;
  }
}

/**
 * Main function for the EqF demo
 *
 * Usage: <program> [csv_file_path] [max_rows] [downsample]
 * Without a path, EqFdata.csv is looked up via findExampleDataFile.
 *
 * @param argc Number of arguments
 * @param argv Array of arguments
 * @return Exit code (0 on success, 1 on any error)
 */
int main(int argc, char* argv[]) {
  std::cout << "ABC-EqF: Attitude-Bias-Calibration Equivariant Filter Demo" << std::endl;
  std::cout << "==============================================================" << std::endl;

  try {
    // Parse command line options
    std::string csvFilePath;
    int maxRows = -1;    // Process all rows by default
    int downsample = 1;  // No downsampling by default

    if (argc > 1) {
      csvFilePath = argv[1];
    } else {
      // Try to find the EqFdata file in the GTSAM examples directory
      try {
        csvFilePath = findExampleDataFile("EqFdata.csv");
      } catch (const std::exception&) {
        std::cerr << "Error: Could not find EqFdata.csv" << std::endl;
        std::cerr << "Usage: " << argv[0] << " [csv_file_path] [max_rows] [downsample]" << std::endl;
        return 1;
      }
    }

    // Optional command line parameters (invalid values are reported by the
    // outer catch; downsample < 1 is rejected by loadDataFromCSV)
    if (argc > 2) {
      maxRows = std::stoi(argv[2]);
    }

    if (argc > 3) {
      downsample = std::stoi(argv[3]);
    }

    // Load data from CSV file
    std::vector<Data> data = loadDataFromCSV(csvFilePath, 0, maxRows, downsample);

    if (data.empty()) {
      std::cerr << "No data available to process. Exiting." << std::endl;
      return 1;
    }

    // Initialize the EqF filter with one calibration state
    const int n_cal = 1;
    const int n_sensors = 2;

    // Initial covariance - larger values allow faster convergence
    Matrix initialSigma = Matrix::Identity(6 + 3 * n_cal, 6 + 3 * n_cal);
    initialSigma.diagonal().head<3>() = Vector3::Constant(0.1);        // Attitude uncertainty
    initialSigma.diagonal().segment<3>(3) = Vector3::Constant(0.01);   // Bias uncertainty
    initialSigma.diagonal().tail<3>() = Vector3::Constant(0.1);        // Calibration uncertainty

    // Create filter
    EqF filter(initialSigma, n_cal, n_sensors);

    // Process data
    processDataWithEqF(filter, data);

  } catch (const std::exception& e) {
    std::cerr << "Error: " << e.what() << std::endl;
    return 1;
  }

  std::cout << "\nEqF demonstration completed successfully." << std::endl;
  return 0;
}