// gtsam/examples/ABC_EQF_Demo.cpp
//
// 444 lines
// 14 KiB
// C++

/**
 * @file ABC_EQF_Demo.cpp
 * @brief Demonstration of the full Attitude-Bias-Calibration Equivariant Filter
 *
 * This demo shows the Equivariant Filter (EqF) for attitude estimation
 * with both gyroscope bias and sensor extrinsic calibration, based on the
 * paper "Overcoming Bias: Equivariant Filter Design for Biased Attitude
 * Estimation with Online Calibration" by Fornasier et al.
 *
 * @author Darshan Rajasekaran
 * @author Jennifer Oum
 */
#include "ABC_EQF.h"

#include <algorithm>
#include <array>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <fstream>
#include <iostream>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Use namespace for convenience
using namespace gtsam;
constexpr size_t N = 1; // Number of calibration states
using M = abc_eqf_lib::State<N>;
using Group = abc_eqf_lib::G<N>;
using EqFilter = abc_eqf_lib::EqF<N>;
using gtsam::abc_eqf_lib::EqF;
using gtsam::abc_eqf_lib::Input;
using gtsam::abc_eqf_lib::Measurement;
/// One sample of the dataset: ground-truth state, input and output data at a
/// single time step.
struct Data {
  M xi;                        ///< Ground-truth state
  Input u;                     ///< Input measurements (angular velocity + covariance)
  std::vector<Measurement> y;  ///< Output (direction) measurements
  int n_meas;                  ///< Number of measurements stored in y
  double t;                    ///< Time stamp of the sample
  double dt;                   ///< Time since the previously loaded sample (0 for the first)
};
//========================================================================
// Data Processing Functions
//========================================================================
/**
 * Load data from a CSV file into a vector of Data objects for the EqF.
 *
 * The file starts with one header line, followed by rows of 39
 * comma-separated values in this column order:
 * - t: Time
 * - q_w, q_x, q_y, q_z: True attitude quaternion
 * - b_x, b_y, b_z: True bias
 * - cq_w_0, cq_x_0, cq_y_0, cq_z_0: True calibration quaternion
 * - w_x, w_y, w_z: Angular velocity measurements
 * - std_w_x, std_w_y, std_w_z: Angular velocity measurement standard deviations
 * - std_b_x, std_b_y, std_b_z: Bias process noise standard deviations
 * - y_x_0, y_y_0, y_z_0, y_x_1, y_y_1, y_z_1: Direction measurements
 * - std_y_x_0, std_y_y_0, std_y_z_0, std_y_x_1, std_y_y_1, std_y_z_1: Direction
 *   measurement standard deviations
 * - d_x_0, d_y_0, d_z_0, d_x_1, d_y_1, d_z_1: Reference directions
 *
 * @param filename   Path of the CSV file to read.
 * @param startRow   File row at which loading starts, counted from 0 with the
 *                   header as row 0; values <= 1 start at the first data row.
 * @param maxRows    Maximum number of rows to load; -1 loads all rows.
 * @param downsample Keep one out of every `downsample` rows (must be >= 1).
 * @return One Data entry per loaded row. Rows with fewer than 39 values are
 *         skipped and counted as errors. Each entry's dt is measured from the
 *         previously loaded row (0 for the first).
 * @throws std::runtime_error if the file cannot be opened.
 */
std::vector<Data> loadDataFromCSV(const std::string& filename, int startRow = 0,
                                  int maxRows = -1, int downsample = 1);
/**
 * Run the EqF over a sequence of samples and print summary results.
 *
 * @param filter        Filter to run; its state is advanced in place.
 * @param data_list     Samples to process, in the order given.
 * @param printInterval Progress-reporting granularity (default 10); see the
 *                      definition for exactly how it is applied.
 */
void processDataWithEqF(EqFilter& filter, const std::vector<Data>& data_list,
                        int printInterval = 10);
//========================================================================
// Data Processing Functions Implementation
//========================================================================
/*
* Loads the test data from the csv file
* startRow First row to load based on csv, 0 by default
* maxRows maximum rows to load, defaults to all rows
* downsample Downsample factor, default 1
* A list of data objects
*/
std::vector<Data> loadDataFromCSV(const std::string& filename, int startRow,
int maxRows, int downsample) {
std::vector<Data> data_list;
std::ifstream file(filename);
if (!file.is_open()) {
throw std::runtime_error("Failed to open file: " + filename);
}
std::cout << "Loading data from " << filename << "..." << std::flush;
std::string line;
int lineNumber = 0;
int rowCount = 0;
int errorCount = 0;
double prevTime = 0.0;
// Skip header
std::getline(file, line);
lineNumber++;
// Skip to startRow
while (lineNumber < startRow && std::getline(file, line)) {
lineNumber++;
}
// Read data
while (std::getline(file, line) && (maxRows == -1 || rowCount < maxRows)) {
lineNumber++;
// Apply downsampling
if ((lineNumber - startRow - 1) % downsample != 0) {
continue;
}
std::istringstream ss(line);
std::string token;
std::vector<double> values;
// Parse line into values
while (std::getline(ss, token, ',')) {
try {
values.push_back(std::stod(token));
} catch (const std::exception& e) {
errorCount++;
values.push_back(0.0); // Use default value
}
}
// Check if we have enough values
if (values.size() < 39) {
errorCount++;
continue;
}
// Extract values
double t = values[0];
double dt = (rowCount == 0) ? 0.0 : t - prevTime;
prevTime = t;
// Create ground truth state
Quaternion quat(values[1], values[2], values[3], values[4]); // w, x, y, z
Rot3 R = Rot3(quat);
Vector3 b(values[5], values[6], values[7]);
Quaternion calQuat(values[8], values[9], values[10],
values[11]); // w, x, y, z
std::array<Rot3, N> S = {Rot3(calQuat)};
M xi(R, b, S);
// Create input
Vector3 w(values[12], values[13], values[14]);
// Create input covariance matrix (6x6)
// First 3x3 block for angular velocity, second 3x3 block for bias process
// noise
Matrix inputCov = Matrix::Zero(6, 6);
inputCov(0, 0) = values[15] * values[15]; // std_w_x^2
inputCov(1, 1) = values[16] * values[16]; // std_w_y^2
inputCov(2, 2) = values[17] * values[17]; // std_w_z^2
inputCov(3, 3) = values[18] * values[18]; // std_b_x^2
inputCov(4, 4) = values[19] * values[19]; // std_b_y^2
inputCov(5, 5) = values[20] * values[20]; // std_b_z^2
Input u{w, inputCov};
// Create measurements
std::vector<Measurement> measurements;
// First measurement (calibrated sensor, cal_idx = 0)
Vector3 y0(values[21], values[22], values[23]);
Vector3 d0(values[33], values[34], values[35]);
// Normalize vectors if needed
if (abs(y0.norm() - 1.0) > 1e-5) y0.normalize();
if (abs(d0.norm() - 1.0) > 1e-5) d0.normalize();
// Measurement covariance
Matrix3 covY0 = Matrix3::Zero();
covY0(0, 0) = values[27] * values[27]; // std_y_x_0^2
covY0(1, 1) = values[28] * values[28]; // std_y_y_0^2
covY0(2, 2) = values[29] * values[29]; // std_y_z_0^2
// Create measurement
measurements.push_back(Measurement{Unit3(y0), Unit3(d0), covY0, 0});
// Second measurement (calibrated sensor, cal_idx = -1)
Vector3 y1(values[24], values[25], values[26]);
Vector3 d1(values[36], values[37], values[38]);
// Normalize vectors if needed
if (abs(y1.norm() - 1.0) > 1e-5) y1.normalize();
if (abs(d1.norm() - 1.0) > 1e-5) d1.normalize();
// Measurement covariance
Matrix3 covY1 = Matrix3::Zero();
covY1(0, 0) = values[30] * values[30]; // std_y_x_1^2
covY1(1, 1) = values[31] * values[31]; // std_y_y_1^2
covY1(2, 2) = values[32] * values[32]; // std_y_z_1^2
// Create measurement
measurements.push_back(Measurement{Unit3(y1), Unit3(d1), covY1, -1});
// Create Data object and add to list
data_list.push_back(Data{xi, u, measurements, 2, t, dt});
rowCount++;
// Show loading progress every 1000 rows
if (rowCount % 1000 == 0) {
std::cout << "." << std::flush;
}
}
std::cout << " Done!" << std::endl;
std::cout << "Loaded " << data_list.size() << " data points";
if (errorCount > 0) {
std::cout << " (" << errorCount << " errors encountered)";
}
std::cout << std::endl;
return data_list;
}
/// Takes in the data and runs an EqF on it and reports the results
void processDataWithEqF(EqFilter& filter, const std::vector<Data>& data_list,
int printInterval) {
if (data_list.empty()) {
std::cerr << "No data to process" << std::endl;
return;
}
std::cout << "Processing " << data_list.size() << " data points with EqF..."
<< std::endl;
// Track performance metrics
std::vector<double> att_errors;
std::vector<double> bias_errors;
std::vector<double> cal_errors;
// Track time for performance measurement
auto start = std::chrono::high_resolution_clock::now();
int totalMeasurements = 0;
int validMeasurements = 0;
// Define constant for converting radians to degrees
const double RAD_TO_DEG = 180.0 / M_PI;
// Print a progress indicator
int progressStep = data_list.size() / 10; // 10 progress updates
if (progressStep < 1) progressStep = 1;
std::cout << "Progress: ";
for (size_t i = 0; i < data_list.size(); i++) {
const Data& data = data_list[i];
// Propagate filter with current input and time step
filter.propagation(data.u, data.dt);
// Process all measurements
for (const auto& y : data.y) {
totalMeasurements++;
// Skip invalid measurements
Vector3 y_vec = y.y.unitVector();
Vector3 d_vec = y.d.unitVector();
if (std::isnan(y_vec[0]) || std::isnan(y_vec[1]) ||
std::isnan(y_vec[2]) || std::isnan(d_vec[0]) ||
std::isnan(d_vec[1]) || std::isnan(d_vec[2])) {
continue;
}
try {
filter.update(y);
validMeasurements++;
} catch (const std::exception& e) {
std::cerr << "Error updating at t=" << data.t << ": " << e.what()
<< std::endl;
}
}
// Get current state estimate
M estimate = filter.stateEstimate();
// Calculate errors
Vector3 att_error = Rot3::Logmap(data.xi.R.between(estimate.R));
Vector3 bias_error = estimate.b - data.xi.b;
Vector3 cal_error = Vector3::Zero();
if (!data.xi.S.empty() && !estimate.S.empty()) {
cal_error = Rot3::Logmap(data.xi.S[0].between(estimate.S[0]));
}
// Store errors
att_errors.push_back(att_error.norm());
bias_errors.push_back(bias_error.norm());
cal_errors.push_back(cal_error.norm());
// Show progress dots
if (i % progressStep == 0) {
std::cout << "." << std::flush;
}
}
std::cout << " Done!" << std::endl;
auto end = std::chrono::high_resolution_clock::now();
std::chrono::duration<double> elapsed = end - start;
// Calculate average errors
double avg_att_error = 0.0;
double avg_bias_error = 0.0;
double avg_cal_error = 0.0;
if (!att_errors.empty()) {
avg_att_error = std::accumulate(att_errors.begin(), att_errors.end(), 0.0) /
att_errors.size();
avg_bias_error =
std::accumulate(bias_errors.begin(), bias_errors.end(), 0.0) /
bias_errors.size();
avg_cal_error = std::accumulate(cal_errors.begin(), cal_errors.end(), 0.0) /
cal_errors.size();
}
// Calculate final errors from last data point
const Data& final_data = data_list.back();
M final_estimate = filter.stateEstimate();
Vector3 final_att_error =
Rot3::Logmap(final_data.xi.R.between(final_estimate.R));
Vector3 final_bias_error = final_estimate.b - final_data.xi.b;
Vector3 final_cal_error = Vector3::Zero();
if (!final_data.xi.S.empty() && !final_estimate.S.empty()) {
final_cal_error =
Rot3::Logmap(final_data.xi.S[0].between(final_estimate.S[0]));
}
// Print summary statistics
std::cout << "\n=== Filter Performance Summary ===" << std::endl;
std::cout << "Processing time: " << elapsed.count() << " seconds"
<< std::endl;
std::cout << "Processed measurements: " << totalMeasurements
<< " (valid: " << validMeasurements << ")" << std::endl;
// Average errors
std::cout << "\n-- Average Errors --" << std::endl;
std::cout << "Attitude: " << (avg_att_error * RAD_TO_DEG) << "°" << std::endl;
std::cout << "Bias: " << avg_bias_error << std::endl;
std::cout << "Calibration: " << (avg_cal_error * RAD_TO_DEG) << "°"
<< std::endl;
// Final errors
std::cout << "\n-- Final Errors --" << std::endl;
std::cout << "Attitude: " << (final_att_error.norm() * RAD_TO_DEG) << "°"
<< std::endl;
std::cout << "Bias: " << final_bias_error.norm() << std::endl;
std::cout << "Calibration: " << (final_cal_error.norm() * RAD_TO_DEG) << "°"
<< std::endl;
// Print a brief comparison of final estimate vs ground truth
std::cout << "\n-- Final State vs Ground Truth --" << std::endl;
std::cout << "Attitude (RPY) - Estimate: "
<< (final_estimate.R.rpy() * RAD_TO_DEG).transpose()
<< "° | Truth: " << (final_data.xi.R.rpy() * RAD_TO_DEG).transpose()
<< "°" << std::endl;
std::cout << "Bias - Estimate: " << final_estimate.b.transpose()
<< " | Truth: " << final_data.xi.b.transpose() << std::endl;
if (!final_estimate.S.empty() && !final_data.xi.S.empty()) {
std::cout << "Calibration (RPY) - Estimate: "
<< (final_estimate.S[0].rpy() * RAD_TO_DEG).transpose()
<< "° | Truth: "
<< (final_data.xi.S[0].rpy() * RAD_TO_DEG).transpose() << "°"
<< std::endl;
}
}
int main(int argc, char* argv[]) {
std::cout << "ABC-EqF: Attitude-Bias-Calibration Equivariant Filter Demo"
<< std::endl;
std::cout << "=============================================================="
<< std::endl;
try {
// Parse command line options
std::string csvFilePath;
int maxRows = -1; // Process all rows by default
int downsample = 1; // No downsampling by default
if (argc > 1) {
csvFilePath = argv[1];
} else {
// Try to find the EQFdata file in the GTSAM examples directory
try {
csvFilePath = findExampleDataFile("EqFdata.csv");
} catch (const std::exception& e) {
std::cerr << "Error: Could not find EqFdata.csv" << std::endl;
std::cerr << "Usage: " << argv[0]
<< " [csv_file_path] [max_rows] [downsample]" << std::endl;
return 1;
}
}
// Optional command line parameters
if (argc > 2) {
maxRows = std::stoi(argv[2]);
}
if (argc > 3) {
downsample = std::stoi(argv[3]);
}
// Load data from CSV file
std::vector<Data> data =
loadDataFromCSV(csvFilePath, 0, maxRows, downsample);
if (data.empty()) {
std::cerr << "No data available to process. Exiting." << std::endl;
return 1;
}
// Initialize the EqF filter with one calibration state
int n_sensors = 2;
// Initial covariance - larger values allow faster convergence
Matrix initialSigma = Matrix::Identity(6 + 3 * N, 6 + 3 * N);
initialSigma.diagonal().head<3>() =
Vector3::Constant(0.1); // Attitude uncertainty
initialSigma.diagonal().segment<3>(3) =
Vector3::Constant(0.01); // Bias uncertainty
initialSigma.diagonal().tail<3>() =
Vector3::Constant(0.1); // Calibration uncertainty
// Create filter
EqFilter filter(initialSigma, n_sensors);
// Process data
processDataWithEqF(filter, data);
} catch (const std::exception& e) {
std::cerr << "Error: " << e.what() << std::endl;
return 1;
}
std::cout << "\nEqF demonstration completed successfully." << std::endl;
return 0;
}