better filenaming and improved plotting

release/4.3a0
Varun Agrawal 2025-02-20 13:01:04 -05:00
parent 2062e0124e
commit 518b067104
1 changed files with 26 additions and 10 deletions

View File

@ -117,24 +117,33 @@ class City10000Dataset:
def plot_all_results(ground_truth, def plot_all_results(ground_truth,
all_results, all_results,
iters=0,
estimate_color=(0.1, 0.1, 0.9, 0.4), estimate_color=(0.1, 0.1, 0.9, 0.4),
estimate_label="Hybrid Factor Graphs", estimate_label="Hybrid Factor Graphs",
text=""): text="",
filename="city10000_results.svg"):
"""Plot the City10000 estimates against the ground truth. """Plot the City10000 estimates against the ground truth.
Args: Args:
ground_truth: The ground truth trajectory as xy values. ground_truth: The ground truth trajectory as xy values.
all_results (List[Tuple(np.ndarray, str)]): All the estimates trajectory as xy values, as well as assginment strings. all_results (List[Tuple(np.ndarray, str)]): All the estimates trajectory as xy values,
as well as assginment strings.
estimate_color (tuple, optional): The color to use for the graph of estimates. estimate_color (tuple, optional): The color to use for the graph of estimates.
Defaults to (0.1, 0.1, 0.9, 0.4). Defaults to (0.1, 0.1, 0.9, 0.4).
estimate_label (str, optional): Label for the estimates, used in the legend. estimate_label (str, optional): Label for the estimates, used in the legend.
Defaults to "Hybrid Factor Graphs". Defaults to "Hybrid Factor Graphs".
""" """
if len(all_results) == 1:
fig, axes = plt.subplots(1, 1)
axes = [axes]
else:
fig, axes = plt.subplots(int(np.ceil(len(all_results) / 2)), 2) fig, axes = plt.subplots(int(np.ceil(len(all_results) / 2)), 2)
axes = axes.flatten()
for i, (estimates, s) in enumerate(all_results): for i, (estimates, s) in enumerate(all_results):
ax = axes[i] ax = axes[i]
ax.axis('equal') ax.axis('equal')
ax.axis((-75.0, 75.0, -75.0, 75.0)) ax.axis((-75.0, 100.0, -75.0, 75.0))
gt = ground_truth[:estimates.shape[0]] gt = ground_truth[:estimates.shape[0]]
ax.plot(gt[:, 0], ax.plot(gt[:, 0],
@ -149,14 +158,17 @@ def plot_all_results(ground_truth,
linewidth=1, linewidth=1,
color=estimate_color, color=estimate_color,
label=estimate_label) label=estimate_label)
ax.legend() # ax.legend()
ax.text(0.0, 100.0, s) # Plot text `s` at (x, y) on axis
ax.text(-60.0, 60.0, s)
fig.suptitle(f"After {iters} iterations")
num_chunks = int(np.ceil(len(text) / 90)) num_chunks = int(np.ceil(len(text) / 90))
text = "\n".join(text[i * 60:(i + 1) * 60] for i in range(num_chunks)) text = "\n".join(text[i * 60:(i + 1) * 60] for i in range(num_chunks))
fig.text(0.0, 0.015, s=text) fig.text(0.0, 0.015, s=text)
fig.savefig("city10000_results.svg", format="svg") fig.savefig(filename, format="svg")
class Experiment: class Experiment:
@ -362,9 +374,9 @@ class Experiment:
# TODO Get cardinality from DiscreteFactor # TODO Get cardinality from DiscreteFactor
discrete_keys.push_back((key, 2)) discrete_keys.push_back((key, 2))
print("plotting all hypotheses") print("plotting all hypotheses")
self.plot_all_hypotheses(discrete_keys, key_t + 1) self.plot_all_hypotheses(discrete_keys, key_t + 1, index)
def plot_all_hypotheses(self, discrete_keys, num_poses): def plot_all_hypotheses(self, discrete_keys, num_poses, num_iters=0):
"""Plot all possible hypotheses.""" """Plot all possible hypotheses."""
# Get ground truth # Get ground truth
@ -410,7 +422,11 @@ class Experiment:
]) ])
all_results.append((poses, assignment_string)) all_results.append((poses, assignment_string))
plot_all_results(gt, all_results, text=fixed_values_str) plot_all_results(gt,
all_results,
iters=num_iters,
text=fixed_values_str,
filename=f"city10000_results_{num_iters}.svg")
def save_results(self, result, final_key, time_list): def save_results(self, result, final_key, time_list):
"""Save results to file.""" """Save results to file."""