Median/plotting/initF

release/4.3a0
Frank Dellaert 2024-10-29 11:56:07 -07:00
parent c68858d7b6
commit 8af0465d92
1 changed files with 28 additions and 21 deletions

View File

@ -32,7 +32,7 @@ from gtsam import (
K = gtsam.symbol_shorthand.K K = gtsam.symbol_shorthand.K
# Methods to compare # Methods to compare
methods = ["Fundamental", "SimpleF", "Essential+Ks", "Calibrated"] methods = ["SimpleF", "Fundamental", "Essential+Ks", "Calibrated"]
# Formatter function for printing keys # Formatter function for printing keys
@ -253,13 +253,14 @@ def plot_results(results):
color = "tab:red" color = "tab:red"
ax1.set_xlabel("Method") ax1.set_xlabel("Method")
ax1.set_ylabel("Final Error", color=color) ax1.set_ylabel("Median Error (log scale)", color=color)
ax1.set_yscale("log")
ax1.bar(methods, final_errors, color=color, alpha=0.6) ax1.bar(methods, final_errors, color=color, alpha=0.6)
ax1.tick_params(axis="y", labelcolor=color) ax1.tick_params(axis="y", labelcolor=color)
ax2 = ax1.twinx() ax2 = ax1.twinx()
color = "tab:blue" color = "tab:blue"
ax2.set_ylabel("Mean Geodesic Distance", color=color) ax2.set_ylabel("Median Geodesic Distance", color=color)
ax2.plot(methods, distances, color=color, marker="o", linestyle="-") ax2.plot(methods, distances, color=color, marker="o", linestyle="-")
ax2.tick_params(axis="y", labelcolor=color) ax2.tick_params(axis="y", labelcolor=color)
@ -278,14 +279,13 @@ def plot_results(results):
fig.tight_layout() fig.tight_layout()
plt.show() plt.show()
# Main function # Main function
def main(): def main():
# Parse command line arguments # Parse command line arguments
parser = argparse.ArgumentParser(description="Compare Fundamental and Essential Matrix Methods") parser = argparse.ArgumentParser(description="Compare Fundamental and Essential Matrix Methods")
parser.add_argument("--num_cameras", type=int, default=4, help="Number of cameras (default: 4)") parser.add_argument("--num_cameras", type=int, default=4, help="Number of cameras (default: 4)")
parser.add_argument("--num_extra_points", type=int, default=12, help="Number of extra random points (default: 12)") parser.add_argument("--num_extra_points", type=int, default=12, help="Number of extra random points (default: 12)")
parser.add_argument("--nr_trials", type=int, default=5, help="Number of trials (default: 5)") parser.add_argument("--num_trials", type=int, default=5, help="Number of trials (default: 5)")
parser.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)") parser.add_argument("--seed", type=int, default=42, help="Random seed (default: 42)")
parser.add_argument("--noise_std", type=float, default=0.5, help="Standard deviation of noise (default: 0.5)") parser.add_argument("--noise_std", type=float, default=0.5, help="Standard deviation of noise (default: 0.5)")
args = parser.parse_args() args = parser.parse_args()
@ -303,12 +303,13 @@ def main():
ground_truth = {method: compute_ground_truth(method, poses, cal) for method in methods} ground_truth = {method: compute_ground_truth(method, poses, cal) for method in methods}
# Get initial estimates # Get initial estimates
initial_estimate = { initial_estimate: dict[Values] = {
method: get_initial_estimate(method, args.num_cameras, ground_truth[method], cal) for method in methods method: get_initial_estimate(method, args.num_cameras, ground_truth[method], cal) for method in methods
} }
simple_f_result: Values = Values()
for trial in range(args.nr_trials): for trial in range(args.num_trials):
print(f"\nTrial {trial + 1}/{args.nr_trials}") print(f"\nTrial {trial + 1}/{args.num_trials}")
# Simulate data # Simulate data
measurements = simulate_data(points, poses, cal, rng, args.noise_std) measurements = simulate_data(points, poses, cal, rng, args.noise_std)
@ -319,20 +320,26 @@ def main():
# Build the factor graph # Build the factor graph
graph = build_factor_graph(method, args.num_cameras, measurements, cal) graph = build_factor_graph(method, args.num_cameras, measurements, cal)
# Assert that the initial error is the same for all methods: # For F, initialize from SimpleF:
if method == methods[0]: if method == "Fundamental":
error0 = graph.error(initial_estimate[method]) initial_estimate[method] = simple_f_result
elif method == "Calibrated":
current_error = graph.error(initial_estimate[method]) * cal.f() * cal.f()
print(error0, current_error)
assert np.allclose(error0, current_error), "Initial errors do not match among methods."
else:
current_error = graph.error(initial_estimate[method])
assert np.allclose(error0, current_error), "Initial errors do not match among methods."
# Optimize the graph # Optimize the graph
result, iterations = optimize(graph, initial_estimate[method], method) result, iterations = optimize(graph, initial_estimate[method], method)
# Store SimpleF result as a set of FundamentalMatrices
if method == "SimpleF":
simple_f_result = Values()
for a in range(args.num_cameras):
b = (a + 1) % args.num_cameras # Next camera
c = (a + 2) % args.num_cameras # Camera after next
key_ab = EdgeKey(a, b).key()
key_ac = EdgeKey(a, c).key()
F1 = result.atSimpleFundamentalMatrix(key_ab).matrix()
F2 = result.atSimpleFundamentalMatrix(key_ac).matrix()
simple_f_result.insert(key_ab, FundamentalMatrix(F1))
simple_f_result.insert(key_ac, FundamentalMatrix(F2))
# Compute distances from ground truth # Compute distances from ground truth
distances = compute_distances(method, result, ground_truth, args.num_cameras, cal) distances = compute_distances(method, result, ground_truth, args.num_cameras, cal)
@ -353,9 +360,9 @@ def main():
# Average results over trials # Average results over trials
for method in methods: for method in methods:
results[method]["final_error"] = np.mean(results[method]["final_error"]) results[method]["final_error"] = np.median(results[method]["final_error"])
results[method]["distances"] = np.mean(results[method]["distances"]) results[method]["distances"] = np.median(results[method]["distances"])
results[method]["iterations"] = np.mean(results[method]["iterations"]) results[method]["iterations"] = np.median(results[method]["iterations"])
# Plot results # Plot results
plot_results(results) plot_results(results)