diff --git a/python/gtsam/examples/HybridCity10000.py b/python/gtsam/examples/HybridCity10000.py index d5d86157d..2e846cca0 100644 --- a/python/gtsam/examples/HybridCity10000.py +++ b/python/gtsam/examples/HybridCity10000.py @@ -112,45 +112,44 @@ class City10000Dataset: return None, None -def plot_estimates(gt, - estimates, - fignum: int, - estimate_color=(0.1, 0.1, 0.9, 0.4), - estimate_label="Hybrid Factor Graphs", - text="graph"): +def plot_all_results(ground_truth, + all_results, + estimate_color=(0.1, 0.1, 0.9, 0.4), + estimate_label="Hybrid Factor Graphs"): """Plot the City10000 estimates against the ground truth. Args: - estimates (np.ndarray): The estimates trajectory as xy values. - fignum (int): The figure number for multiple plots. + ground_truth: The ground truth trajectory as xy values. + all_results (List[Tuple(np.ndarray, str)]): All the estimates trajectory as xy values, as well as assginment strings. estimate_color (tuple, optional): The color to use for the graph of estimates. Defaults to (0.1, 0.1, 0.9, 0.4). estimate_label (str, optional): Label for the estimates, used in the legend. Defaults to "Hybrid Factor Graphs". """ - fig = plt.figure(fignum) - ax = fig.gca() - ax.axis('equal') - ax.axis((-75.0, 75.0, -75.0, 75.0)) + print(len(all_results)) + fig, axes = plt.subplots(int(np.ceil(len(all_results) / 2)), 2) + for i, (estimates, text) in enumerate(all_results): + ax = axes[i] + ax.axis('equal') + ax.axis((-75.0, 75.0, -75.0, 75.0)) - gt = gt[:estimates.shape[0]] - ax.plot(gt[:, 0], - gt[:, 1], - '--', - linewidth=1, - color=(0.1, 0.7, 0.1, 0.5), - label="Ground Truth") - ax.plot(estimates[:, 0], - estimates[:, 1], - '-', - linewidth=1, - color=estimate_color, - label=estimate_label) - ax.legend() - fig.text(0.1, 0.03, text) + gt = ground_truth[:estimates.shape[0]] + ax.plot(gt[:, 0], + gt[:, 1], + '--', + linewidth=1, + color=(0.1, 0.7, 0.1, 0.5), + label="Ground Truth") + ax.plot(estimates[:, 0], + estimates[:, 1], + '-', + linewidth=1, + color=estimate_color, + label=estimate_label) + ax.legend() + ax.text(0.0, 100.0, text) - filename = f"city10000_{text.replace(' ', '_')}.svg" - fig.savefig(filename, format="svg") + fig.savefig("city10000_results.svg", format="svg") class Experiment: @@ -269,7 +268,8 @@ class Experiment: num_measurements = len(pose_array) # Take the first one as the initial estimate - odom_pose = pose_array[np.random.choice(num_measurements)] + # odom_pose = pose_array[np.random.choice(num_measurements)] + odom_pose = pose_array[0] if key_s == key_t - 1: # Odometry factor if num_measurements > 1: @@ -358,14 +358,16 @@ class Experiment: gt = np.loadtxt(gtsam.findExampleDataFile("ISAM2_GT_city10000.txt"), delimiter=" ") - # Get all possible assignments - if discrete_keys.size() > 5: - print("Too many discrete keys to plot all hypotheses. Exiting.") - exit(0) + dkeys = gtsam.DiscreteKeys() + for i in range(discrete_keys.size()): + key, cardinality = discrete_keys.at(i) + if key not in self.smoother_.fixedValues().keys(): + dkeys.push_back((key, cardinality)) - all_assignments = gtsam.cartesianProduct(discrete_keys) + all_assignments = gtsam.cartesianProduct(dkeys) - for idx, assignment in enumerate(all_assignments): + all_results = [] + for assignment in all_assignments: result = gtsam.Values() gbn = self.smoother_.hybridBayesNet().choose(assignment) @@ -381,21 +383,18 @@ class Experiment: delta = self.smoother_.hybridBayesNet().optimize(assignment) result.insert_or_assign(self.initial_.retract(delta)) - poses = [] + poses = np.zeros((num_poses, 3)) for i in range(num_poses): pose = result.atPose2(X(i)) - poses.append((pose.x(), pose.y(), pose.theta())) - poses = np.asarray(poses) + poses[i] = np.asarray((pose.x(), pose.y(), pose.theta())) assignment_string = " ".join([ f"{gtsam.DefaultKeyFormatter(k)}={v}" for k, v in assignment.items() ]) + all_results.append((poses, assignment_string)) - plot_estimates(gt, - estimates=poses, - fignum=idx, - text=assignment_string) + plot_all_results(gt, all_results) def save_results(self, result, final_key, time_list): """Save results to file."""