diff --git a/python/gtsam/examples/HybridCity10000.py b/python/gtsam/examples/HybridCity10000.py index 92102a0e3..5f54701d9 100644 --- a/python/gtsam/examples/HybridCity10000.py +++ b/python/gtsam/examples/HybridCity10000.py @@ -117,24 +117,33 @@ class City10000Dataset: def plot_all_results(ground_truth, all_results, + iters=0, estimate_color=(0.1, 0.1, 0.9, 0.4), estimate_label="Hybrid Factor Graphs", - text=""): + text="", + filename="city10000_results.svg"): """Plot the City10000 estimates against the ground truth. Args: ground_truth: The ground truth trajectory as xy values. - all_results (List[Tuple(np.ndarray, str)]): All the estimates trajectory as xy values, as well as assginment strings. + all_results (List[Tuple(np.ndarray, str)]): All the estimates trajectory as xy values, + as well as assginment strings. estimate_color (tuple, optional): The color to use for the graph of estimates. Defaults to (0.1, 0.1, 0.9, 0.4). estimate_label (str, optional): Label for the estimates, used in the legend. Defaults to "Hybrid Factor Graphs". """ - fig, axes = plt.subplots(int(np.ceil(len(all_results) / 2)), 2) + if len(all_results) == 1: + fig, axes = plt.subplots(1, 1) + axes = [axes] + else: + fig, axes = plt.subplots(int(np.ceil(len(all_results) / 2)), 2) + axes = axes.flatten() + for i, (estimates, s) in enumerate(all_results): ax = axes[i] ax.axis('equal') - ax.axis((-75.0, 75.0, -75.0, 75.0)) + ax.axis((-75.0, 100.0, -75.0, 75.0)) gt = ground_truth[:estimates.shape[0]] ax.plot(gt[:, 0], @@ -149,14 +158,17 @@ def plot_all_results(ground_truth, linewidth=1, color=estimate_color, label=estimate_label) - ax.legend() - ax.text(0.0, 100.0, s) + # ax.legend() + # Plot text `s` at (x, y) on axis + ax.text(-60.0, 60.0, s) + + fig.suptitle(f"After {iters} iterations") num_chunks = int(np.ceil(len(text) / 90)) text = "\n".join(text[i * 60:(i + 1) * 60] for i in range(num_chunks)) fig.text(0.0, 0.015, s=text) - fig.savefig("city10000_results.svg", format="svg") + fig.savefig(filename, format="svg") class Experiment: @@ -362,9 +374,9 @@ class Experiment: # TODO Get cardinality from DiscreteFactor discrete_keys.push_back((key, 2)) print("plotting all hypotheses") - self.plot_all_hypotheses(discrete_keys, key_t + 1) + self.plot_all_hypotheses(discrete_keys, key_t + 1, index) - def plot_all_hypotheses(self, discrete_keys, num_poses): + def plot_all_hypotheses(self, discrete_keys, num_poses, num_iters=0): """Plot all possible hypotheses.""" # Get ground truth @@ -410,7 +422,11 @@ class Experiment: ]) all_results.append((poses, assignment_string)) - plot_all_results(gt, all_results, text=fixed_values_str) + plot_all_results(gt, + all_results, + iters=num_iters, + text=fixed_values_str, + filename=f"city10000_results_{num_iters}.svg") def save_results(self, result, final_key, time_list): """Save results to file."""