code to plot 3D covariance ellipsoid

release/4.3a0
Varun Agrawal 2020-03-20 12:03:37 -04:00
parent f7d86a80cf
commit 26d6cb3d6e
1 changed files with 71 additions and 22 deletions

View File

@ -5,8 +5,10 @@ import matplotlib.pyplot as plt
from matplotlib import patches from matplotlib import patches
from mpl_toolkits.mplot3d import Axes3D from mpl_toolkits.mplot3d import Axes3D
import gtsam
def set_axes_equal(ax):
def set_axes_equal(fignum):
""" """
Make axes of 3D plot have equal scale so that spheres appear as spheres, Make axes of 3D plot have equal scale so that spheres appear as spheres,
cubes as cubes, etc.. This is one possible solution to Matplotlib's cubes as cubes, etc.. This is one possible solution to Matplotlib's
@ -14,6 +16,8 @@ def set_axes_equal(ax):
Input Input
ax: a matplotlib axis, e.g., as output from plt.gca(). ax: a matplotlib axis, e.g., as output from plt.gca().
""" """
fig = plt.figure(fignum)
ax = fig.gca(projection='3d')
limits = np.array([ limits = np.array([
ax.get_xlim3d(), ax.get_xlim3d(),
@ -29,6 +33,46 @@ def set_axes_equal(ax):
ax.set_zlim3d([origin[2] - radius, origin[2] + radius]) ax.set_zlim3d([origin[2] - radius, origin[2] + radius])
def ellipsoid(xc, yc, zc, rx, ry, rz, n):
"""Numpy equivalent of Matlab's ellipsoid function"""
u = np.linspace(0, 2*np.pi, n+1)
v = np.linspace(0, np.pi, n+1)
x = -rx * np.outer(np.cos(u), np.sin(v)).T
y = -ry * np.outer(np.sin(u), np.sin(v)).T
z = -rz * np.outer(np.ones_like(u), np.cos(v)).T
return x, y, z
def plot_covariance_ellipse_3d(axes, origin, P, scale=1, n=8, alpha=0.5):
"""
Plots a Gaussian as an uncertainty ellipse
Based on Maybeck Vol 1, page 366
k=2.296 corresponds to 1 std, 68.26% of all probability
k=11.82 corresponds to 3 std, 99.74% of all probability
"""
k = 11.82
U, S, _ = np.linalg.svd(P)
radii = k * np.sqrt(S)
radii = radii * scale
rx, ry, rz = radii
# generate data for "unrotated" ellipsoid
xc, yc, zc = ellipsoid(0, 0, 0, rx, ry, rz, n)
# rotate data with orientation matrix U and center c
data = np.kron(U[:, 0:1], xc) + np.kron(U[:, 1:2], yc) + \
np.kron(U[:, 2:3], zc)
n = data.shape[1]
x = data[0:n, :] + origin[0]
y = data[n:2*n, :] + origin[1]
z = data[2*n:, :] + origin[2]
axes.plot_surface(x, y, z, alpha=alpha, cmap='hot')
def plot_pose2_on_axes(axes, pose, axis_length=0.1, covariance=None): def plot_pose2_on_axes(axes, pose, axis_length=0.1, covariance=None):
"""Plot a 2D pose on given axis 'axes' with given 'axis_length'.""" """Plot a 2D pose on given axis 'axes' with given 'axis_length'."""
# get rotation and translation (center) # get rotation and translation (center)
@ -68,19 +112,21 @@ def plot_pose2(fignum, pose, axis_length=0.1, covariance=None):
plot_pose2_on_axes(axes, pose, axis_length, covariance) plot_pose2_on_axes(axes, pose, axis_length, covariance)
def plot_point3_on_axes(axes, point, linespec): def plot_point3_on_axes(axes, point, linespec, P=None):
"""Plot a 3D point on given axis 'axes' with given 'linespec'.""" """Plot a 3D point on given axis 'axes' with given 'linespec'."""
axes.plot([point.x()], [point.y()], [point.z()], linespec) axes.plot([point.x()], [point.y()], [point.z()], linespec)
if P is not None:
plot_covariance_ellipse_3d(axes, point.vector(), P)
def plot_point3(fignum, point, linespec): def plot_point3(fignum, point, linespec, P=None):
"""Plot a 3D point on given figure with given 'linespec'.""" """Plot a 3D point on given figure with given 'linespec'."""
fig = plt.figure(fignum) fig = plt.figure(fignum)
axes = fig.gca(projection='3d') axes = fig.gca(projection='3d')
plot_point3_on_axes(axes, point, linespec) plot_point3_on_axes(axes, point, linespec, P)
def plot_3d_points(fignum, values, linespec, marginals=None): def plot_3d_points(fignum, values, linespec="g*", marginals=None):
""" """
Plots the Point3s in 'values', with optional covariances. Plots the Point3s in 'values', with optional covariances.
Finds all the Point3 objects in the given Values object and plots them. Finds all the Point3 objects in the given Values object and plots them.
@ -93,23 +139,25 @@ def plot_3d_points(fignum, values, linespec, marginals=None):
# Plot points and covariance matrices # Plot points and covariance matrices
for i in range(keys.size()): for i in range(keys.size()):
try: try:
p = values.atPoint3(keys.at(i)) key = keys.at(i)
# if haveMarginals point = values.atPoint3(key)
# P = marginals.marginalCovariance(key); if marginals is not None:
# gtsam.plot_point3(p, linespec, P); P = marginals.marginalCovariance(key);
# else else:
plot_point3(fignum, p, linespec) P = None
plot_point3(fignum, point, linespec, P)
except RuntimeError: except RuntimeError:
continue continue
# I guess it's not a Point3 # I guess it's not a Point3
def plot_pose3_on_axes(axes, pose, axis_length=0.1): def plot_pose3_on_axes(axes, pose, P=None, scale=1, axis_length=0.1):
"""Plot a 3D pose on given axis 'axes' with given 'axis_length'.""" """Plot a 3D pose on given axis 'axes' with given 'axis_length'."""
# get rotation and translation (center) # get rotation and translation (center)
gRp = pose.rotation().matrix() # rotation from pose to global gRp = pose.rotation().matrix() # rotation from pose to global
t = pose.translation() origin = pose.translation().vector()
origin = np.array([t.x(), t.y(), t.z()])
# draw the camera axes # draw the camera axes
x_axis = origin + gRp[:, 0] * axis_length x_axis = origin + gRp[:, 0] * axis_length
@ -125,17 +173,18 @@ def plot_pose3_on_axes(axes, pose, axis_length=0.1):
axes.plot(line[:, 0], line[:, 1], line[:, 2], 'b-') axes.plot(line[:, 0], line[:, 1], line[:, 2], 'b-')
# plot the covariance # plot the covariance
# TODO (dellaert): make this work if P is not None:
# if (nargin>2) && (~isempty(P)) # covariance matrix in pose coordinate frame
# pPp = P(4:6,4:6); % covariance matrix in pose coordinate frame pPp = P[3:6, 3:6]
# gPp = gRp*pPp*gRp'; % convert the covariance matrix to global coordinate frame # convert the covariance matrix to global coordinate frame
# gtsam.covarianceEllipse3D(origin,gPp); gPp = gRp @ pPp @ gRp.T
# end plot_covariance_ellipse_3d(axes, origin, gPp)
def plot_pose3(fignum, pose, axis_length=0.1): def plot_pose3(fignum, pose, P, axis_length=0.1):
"""Plot a 3D pose on given figure with given 'axis_length'.""" """Plot a 3D pose on given figure with given 'axis_length'."""
# get figure object # get figure object
fig = plt.figure(fignum) fig = plt.figure(fignum)
axes = fig.gca(projection='3d') axes = fig.gca(projection='3d')
plot_pose3_on_axes(axes, pose, axis_length) plot_pose3_on_axes(axes, pose, P=P, axis_length=axis_length)