From 6660e2a532a76656e75507cdb10e46a12b8c4161 Mon Sep 17 00:00:00 2001 From: Varun Agrawal Date: Fri, 24 Jul 2020 09:43:37 -0500 Subject: [PATCH] added axis labels and figure titles as optional params --- cython/gtsam/utils/plot.py | 68 ++++++++++++++++++++++++++++++-------- 1 file changed, 54 insertions(+), 14 deletions(-) diff --git a/cython/gtsam/utils/plot.py b/cython/gtsam/utils/plot.py index 93d8ba47b..1e976a69e 100644 --- a/cython/gtsam/utils/plot.py +++ b/cython/gtsam/utils/plot.py @@ -135,7 +135,8 @@ def plot_pose2_on_axes(axes, pose, axis_length=0.1, covariance=None): axes.add_patch(e1) -def plot_pose2(fignum, pose, axis_length=0.1, covariance=None): +def plot_pose2(fignum, pose, axis_length=0.1, covariance=None, + axis_labels=('X axis', 'Y axis', 'Z axis')): """ Plot a 2D pose on given figure with given `axis_length`. @@ -144,6 +145,7 @@ def plot_pose2(fignum, pose, axis_length=0.1, covariance=None): pose (gtsam.Pose2): The pose to be plotted. axis_length (float): The length of the camera axes. covariance (numpy.ndarray): Marginal covariance matrix to plot the uncertainty of the estimation. + axis_labels (iterable[string]): List of axis labels to set. """ # get figure object fig = plt.figure(fignum) @@ -151,6 +153,12 @@ def plot_pose2(fignum, pose, axis_length=0.1, covariance=None): plot_pose2_on_axes(axes, pose, axis_length=axis_length, covariance=covariance) + axes.set_xlabel(axis_labels[0]) + axes.set_ylabel(axis_labels[1]) + axes.set_zlabel(axis_labels[2]) + + return fig + def plot_point3_on_axes(axes, point, linespec, P=None): """ @@ -167,7 +175,8 @@ def plot_point3_on_axes(axes, point, linespec, P=None): plot_covariance_ellipse_3d(axes, point.vector(), P) -def plot_point3(fignum, point, linespec, P=None): +def plot_point3(fignum, point, linespec, P=None, + axis_labels=('X axis', 'Y axis', 'Z axis')): """ Plot a 3D point on given figure with given `linespec`. @@ -176,13 +185,25 @@ def plot_point3(fignum, point, linespec, P=None): point (gtsam.Point3): The point to be plotted. linespec (string): String representing formatting options for Matplotlib. P (numpy.ndarray): Marginal covariance matrix to plot the uncertainty of the estimation. + axis_labels (iterable[string]): List of axis labels to set. + + Returns: + fig: The matplotlib figure. + """ fig = plt.figure(fignum) axes = fig.gca(projection='3d') plot_point3_on_axes(axes, point, linespec, P) + axes.set_xlabel(axis_labels[0]) + axes.set_ylabel(axis_labels[1]) + axes.set_zlabel(axis_labels[2]) -def plot_3d_points(fignum, values, linespec="g*", marginals=None): + return fig + + +def plot_3d_points(fignum, values, linespec="g*", marginals=None, + title="3D Points", axis_labels=('X axis', 'Y axis', 'Z axis')): """ Plots the Point3s in `values`, with optional covariances. Finds all the Point3 objects in the given Values object and plots them. @@ -193,7 +214,9 @@ def plot_3d_points(fignum, values, linespec="g*", marginals=None): fignum (int): Integer representing the figure number to use for plotting. values (gtsam.Values): Values dictionary consisting of points to be plotted. linespec (string): String representing formatting options for Matplotlib. - covariance (numpy.ndarray): Marginal covariance matrix to plot the uncertainty of the estimation. + marginals (numpy.ndarray): Marginal covariance matrix to plot the uncertainty of the estimation. + title (string): The title of the plot. + axis_labels (iterable[string]): List of axis labels to set. """ keys = values.keys() @@ -208,12 +231,16 @@ def plot_3d_points(fignum, values, linespec="g*", marginals=None): else: covariance = None - plot_point3(fignum, point, linespec, covariance) + fig = plot_point3(fignum, point, linespec, covariance, + axis_labels=axis_labels) except RuntimeError: continue # I guess it's not a Point3 + fig.suptitle(title) + fig.canvas.set_window_title(title.lower()) + def plot_pose3_on_axes(axes, pose, axis_length=0.1, P=None, scale=1): """ @@ -251,7 +278,8 @@ def plot_pose3_on_axes(axes, pose, axis_length=0.1, P=None, scale=1): plot_covariance_ellipse_3d(axes, origin, gPp) -def plot_pose3(fignum, pose, axis_length=0.1, P=None): +def plot_pose3(fignum, pose, axis_length=0.1, P=None, + axis_labels=('X axis', 'Y axis', 'Z axis')): """ Plot a 3D pose on given figure with given `axis_length`. @@ -260,6 +288,10 @@ def plot_pose3(fignum, pose, axis_length=0.1, P=None): pose (gtsam.Pose3): 3D pose to be plotted. linespec (string): String representing formatting options for Matplotlib. P (numpy.ndarray): Marginal covariance matrix to plot the uncertainty of the estimation. + axis_labels (iterable[string]): List of axis labels to set. + + Returns: + fig: The matplotlib figure. """ # get figure object fig = plt.figure(fignum) @@ -267,12 +299,15 @@ def plot_pose3(fignum, pose, axis_length=0.1, P=None): plot_pose3_on_axes(axes, pose, P=P, axis_length=axis_length) - axes.set_xlabel('X axis') - axes.set_ylabel('Y axis') - axes.set_zlabel('Z axis') + axes.set_xlabel(axis_labels[0]) + axes.set_ylabel(axis_labels[1]) + axes.set_zlabel(axis_labels[2]) + + return fig -def plot_trajectory(fignum, values, scale=1, marginals=None): +def plot_trajectory(fignum, values, scale=1, marginals=None, + title="Plot Trajectory", axis_labels=('X axis', 'Y axis', 'Z axis')): """ Plot a complete 3D trajectory using poses in `values`. @@ -282,6 +317,8 @@ def plot_trajectory(fignum, values, scale=1, marginals=None): scale (float): Value to scale the poses by. marginals (gtsam.Marginals): Marginalized probability values of the estimation. Used to plot uncertainty bounds. + title (string): The title of the plot. + axis_labels (iterable[string]): List of axis labels to set. """ pose3Values = gtsam.utilities_allPose3s(values) keys = gtsam.KeyVector(pose3Values.keys()) @@ -307,8 +344,8 @@ def plot_trajectory(fignum, values, scale=1, marginals=None): else: covariance = None - plot_pose3(fignum, lastPose, P=covariance, - axis_length=scale) + fig = plot_pose3(fignum, lastPose, P=covariance, + axis_length=scale, axis_labels=axis_labels) lastIndex = i @@ -322,8 +359,11 @@ def plot_trajectory(fignum, values, scale=1, marginals=None): else: covariance = None - plot_pose3(fignum, lastPose, P=covariance, - axis_length=scale) + fig = plot_pose3(fignum, lastPose, P=covariance, + axis_length=scale, axis_labels=axis_labels) except: pass + + fig.suptitle(title) + fig.canvas.set_window_title(title.lower())