gtsam/python/gtsam/examples/gtsam_plotly.py

520 lines
17 KiB
Python

# gtsam_plotly.py
import base64
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple
import graphviz
import numpy as np
import plotly.graph_objects as go
from tqdm.notebook import tqdm # Optional progress bar
import gtsam
# --- Dataclass for History ---
@dataclass
class SlamFrameData:
"""Holds all data needed for a single frame of the SLAM animation."""
step_index: int
results: gtsam.Values # Estimates for variables active at this step
marginals: Optional[gtsam.Marginals] # Marginals for variables active at this step
graph_dot_string: Optional[str] = None # Graphviz DOT string for visualization
# --- Core Ellipse Calculation & Path Generation ---
def create_ellipse_path_from_cov(
cx: float, cy: float, cov_matrix: np.ndarray, scale: float = 2.0, N: int = 60
) -> str:
"""Generates SVG path string for an ellipse from 2x2 covariance."""
cov = cov_matrix[:2, :2] + np.eye(2) * 1e-9 # Ensure positive definite 2x2
try:
eigvals, eigvecs = np.linalg.eigh(cov)
eigvals = np.maximum(eigvals, 1e-9) # Ensure positive eigenvalues
minor_std, major_std = np.sqrt(eigvals) # eigh sorts ascending
v_minor, v_major = eigvecs[:, 0], eigvecs[:, 1]
except np.linalg.LinAlgError:
# Fallback to a small circle if decomposition fails
radius = 0.1 * scale
t = np.linspace(0, 2 * np.pi, N)
x_p = cx + radius * np.cos(t)
y_p = cy + radius * np.sin(t)
else:
# Parametric equation using eigenvectors and eigenvalues
t = np.linspace(0, 2 * np.pi, N)
cos_t, sin_t = np.cos(t), np.sin(t)
x_p = cx + scale * (
major_std * cos_t * v_major[0] + minor_std * sin_t * v_minor[0]
)
y_p = cy + scale * (
major_std * cos_t * v_major[1] + minor_std * sin_t * v_minor[1]
)
# Build SVG path string
path = (
f"M {x_p[0]},{y_p[0]} "
+ " ".join(f"L{x_},{y_}" for x_, y_ in zip(x_p[1:], y_p[1:]))
+ " Z"
)
return path
# --- Plotly Element Generators ---
def create_gt_landmarks_trace(
landmarks_gt: Optional[np.ndarray],
) -> Optional[go.Scatter]:
"""Creates scatter trace for ground truth landmarks."""
if landmarks_gt is None or landmarks_gt.size == 0:
return None
return go.Scatter(
x=landmarks_gt[0, :],
y=landmarks_gt[1, :],
mode="markers",
marker=dict(color="black", size=8, symbol="star"),
name="Landmarks GT",
)
def create_gt_path_trace(poses_gt: Optional[List[gtsam.Pose2]]) -> Optional[go.Scatter]:
"""Creates line trace for ground truth path."""
if not poses_gt:
return None
return go.Scatter(
x=[p.x() for p in poses_gt],
y=[p.y() for p in poses_gt],
mode="lines",
line=dict(color="gray", width=1, dash="dash"),
name="Path GT",
)
def create_est_path_trace(
est_path_x: List[float], est_path_y: List[float]
) -> go.Scatter:
"""Creates trace for the estimated path (all poses up to current)."""
return go.Scatter(
x=est_path_x,
y=est_path_y,
mode="lines+markers",
line=dict(color="rgba(255, 0, 0, 0.3)", width=1), # Fainter line for history
marker=dict(size=4, color="red"), # Keep markers prominent
name="Path Est",
)
def create_current_pose_trace(
current_pose: Optional[gtsam.Pose2],
) -> Optional[go.Scatter]:
"""Creates a single marker trace for the current estimated pose."""
if current_pose is None:
return None
return go.Scatter(
x=[current_pose.x()],
y=[current_pose.y()],
mode="markers",
marker=dict(color="magenta", size=10, symbol="cross"),
name="Current Pose",
)
def create_est_landmarks_trace(
est_landmarks_x: List[float], est_landmarks_y: List[float]
) -> Optional[go.Scatter]:
"""Creates trace for currently estimated landmarks."""
if not est_landmarks_x:
return None
return go.Scatter(
x=est_landmarks_x,
y=est_landmarks_y,
mode="markers",
marker=dict(color="blue", size=6, symbol="x"),
name="Landmarks Est",
)
def _create_ellipse_shape_dict(
cx: float, cy: float, cov: np.ndarray, scale: float, fillcolor: str, line_color: str
) -> Dict[str, Any]:
"""Helper: Creates dict for a Plotly ellipse shape from covariance."""
path = create_ellipse_path_from_cov(cx, cy, cov, scale)
return dict(
type="path",
path=path,
xref="x",
yref="y",
fillcolor=fillcolor,
line_color=line_color,
opacity=0.7, # Make ellipses slightly transparent
)
def create_pose_ellipse_shape(
pose_mean_xy: np.ndarray, pose_cov: np.ndarray, scale: float
) -> Dict[str, Any]:
"""Creates shape dict for a pose covariance ellipse."""
return _create_ellipse_shape_dict(
cx=pose_mean_xy[0],
cy=pose_mean_xy[1],
cov=pose_cov,
scale=scale,
fillcolor="rgba(255,0,255,0.2)", # Magenta fill
line_color="rgba(255,0,255,0.5)", # Magenta line
)
def create_landmark_ellipse_shape(
lm_mean_xy: np.ndarray, lm_cov: np.ndarray, scale: float
) -> Dict[str, Any]:
"""Creates shape dict for a landmark covariance ellipse."""
return _create_ellipse_shape_dict(
cx=lm_mean_xy[0],
cy=lm_mean_xy[1],
cov=lm_cov,
scale=scale,
fillcolor="rgba(0,0,255,0.1)", # Blue fill
line_color="rgba(0,0,255,0.3)", # Blue line
)
def dot_string_to_base64_svg(
dot_string: Optional[str], engine: str = "neato"
) -> Optional[str]:
"""Renders a DOT string to a base64 encoded SVG using graphviz."""
if not dot_string:
return None
try:
source = graphviz.Source(dot_string, engine=engine)
svg_bytes = source.pipe(format="svg")
encoded_string = base64.b64encode(svg_bytes).decode("utf-8")
return f"data:image/svg+xml;base64,{encoded_string}"
except graphviz.backend.execute.CalledProcessError as e:
print(f"Graphviz rendering error ({engine}): {e}")
return None
except Exception as e:
print(f"Unexpected error during Graphviz SVG generation: {e}")
return None
# --- Frame Content Generation ---
def generate_frame_content(
frame_data: SlamFrameData,
X: Callable[[int], int],
L: Callable[[int], int],
max_landmark_index: int,
ellipse_scale: float = 2.0,
graphviz_engine: str = "neato",
verbose: bool = False,
) -> Tuple[List[go.Scatter], List[Dict[str, Any]], Optional[Dict[str, Any]]]:
"""Generates dynamic traces, shapes, and layout image for a single frame."""
k = frame_data.step_index
# Use the results specific to this frame for current elements
step_results = frame_data.results
step_marginals = frame_data.marginals
frame_dynamic_traces: List[go.Scatter] = []
frame_shapes: List[Dict[str, Any]] = []
layout_image: Optional[Dict[str, Any]] = None
# 1. Estimated Path (Full History or Partial)
est_path_x = []
est_path_y = []
current_pose_est = None
# Plot poses currently existing in the step_results (e.g., within lag)
for i in range(k + 1): # Check poses up to current step index
pose_key = X(i)
if step_results.exists(pose_key):
pose = step_results.atPose2(pose_key)
est_path_x.append(pose.x())
est_path_y.append(pose.y())
if i == k:
current_pose_est = pose
path_trace = create_est_path_trace(est_path_x, est_path_y)
if path_trace:
frame_dynamic_traces.append(path_trace)
# Add a distinct marker for the current pose estimate
current_pose_trace = create_current_pose_trace(current_pose_est)
if current_pose_trace:
frame_dynamic_traces.append(current_pose_trace)
# 2. Estimated Landmarks (Only those present in step_results)
est_landmarks_x, est_landmarks_y, landmark_keys = [], [], []
for j in range(max_landmark_index + 1):
lm_key = L(j)
# Check existence in the results for the *current frame*
if step_results.exists(lm_key):
lm_point = step_results.atPoint2(lm_key)
est_landmarks_x.append(lm_point[0])
est_landmarks_y.append(lm_point[1])
landmark_keys.append(lm_key) # Store keys for covariance lookup
lm_trace = create_est_landmarks_trace(est_landmarks_x, est_landmarks_y)
if lm_trace:
frame_dynamic_traces.append(lm_trace)
# 3. Covariance Ellipses (Only for items in step_results AND step_marginals)
if step_marginals:
# Current Pose Ellipse
pose_key = X(k)
# Retrieve cov from marginals specific to this frame
cov = step_marginals.marginalCovariance(pose_key)
# Ensure mean comes from the pose in current results
mean = step_results.atPose2(pose_key).translation()
frame_shapes.append(create_pose_ellipse_shape(mean, cov, ellipse_scale))
# Landmark Ellipses (Iterate over keys found in step_results)
for lm_key in landmark_keys:
try:
# Retrieve cov from marginals specific to this frame
cov = step_marginals.marginalCovariance(lm_key)
# Ensure mean comes from the landmark in current results
mean = step_results.atPoint2(lm_key)
frame_shapes.append(
create_landmark_ellipse_shape(mean, cov, ellipse_scale)
)
except RuntimeError: # Covariance might not be available
if verbose:
print(
f"Warn: LM {gtsam.Symbol(lm_key).index()} cov not in marginals @ step {k}"
)
except Exception as e:
if verbose:
print(
f"Warn: LM {gtsam.Symbol(lm_key).index()} cov OTHER err @ step {k}: {e}"
)
# 4. Graph Image for Layout
img_src = dot_string_to_base64_svg(
frame_data.graph_dot_string, engine=graphviz_engine
)
if img_src:
layout_image = dict(
source=img_src,
xref="paper",
yref="paper",
x=0,
y=1,
sizex=0.48,
sizey=1,
xanchor="left",
yanchor="top",
layer="below",
sizing="contain",
)
# Return dynamic elements for this frame
return frame_dynamic_traces, frame_shapes, layout_image
# --- Figure Configuration ---
def configure_figure_layout(
fig: go.Figure,
num_steps: int,
world_size: float,
initial_shapes: List[Dict[str, Any]],
initial_image: Optional[Dict[str, Any]],
) -> None:
"""Configures Plotly figure layout for side-by-side display."""
steps = list(range(num_steps + 1))
plot_domain = [0.52, 1.0] # Right pane for the SLAM plot
sliders = [
dict(
active=0,
currentvalue={"prefix": "Step: "},
pad={"t": 50},
steps=[
dict(
label=str(k),
method="animate",
args=[
[str(k)],
dict(
mode="immediate",
frame=dict(duration=100, redraw=True),
transition=dict(duration=0),
),
],
)
for k in steps
],
)
]
updatemenus = [
dict(
type="buttons",
showactive=False,
direction="left",
pad={"r": 10, "t": 87},
x=plot_domain[0],
xanchor="left",
y=0,
yanchor="top",
buttons=[
dict(
label="Play",
method="animate",
args=[
None,
dict(
mode="immediate",
frame=dict(duration=100, redraw=True),
transition=dict(duration=0),
fromcurrent=True,
),
],
),
dict(
label="Pause",
method="animate",
args=[
[None],
dict(
mode="immediate",
frame=dict(duration=0, redraw=False),
transition=dict(duration=0),
),
],
),
],
)
]
fig.update_layout(
title="Factor Graph SLAM Animation (Graph Left, Results Right)",
xaxis=dict(
range=[-world_size / 2 - 2, world_size / 2 + 2],
domain=plot_domain,
constrain="domain",
),
yaxis=dict(
range=[-world_size / 2 - 2, world_size / 2 + 2],
scaleanchor="x",
scaleratio=1,
domain=[0, 1],
),
width=1000,
height=600,
hovermode="closest",
updatemenus=updatemenus,
sliders=sliders,
shapes=initial_shapes, # Initial shapes (frame 0)
images=([initial_image] if initial_image else []), # Initial image (frame 0)
showlegend=True, # Keep legend for clarity
legend=dict(
x=plot_domain[0],
y=1,
traceorder="normal", # Position legend
bgcolor="rgba(255,255,255,0.5)",
),
)
# --- Main Animation Orchestrator ---
def create_slam_animation(
history: List[SlamFrameData],
X: Callable[[int], int],
L: Callable[[int], int],
max_landmark_index: int,
landmarks_gt_array: Optional[np.ndarray] = None,
poses_gt: Optional[List[gtsam.Pose2]] = None,
world_size: float = 20.0,
ellipse_scale: float = 2.0,
graphviz_engine: str = "neato",
verbose_cov_errors: bool = False,
) -> go.Figure:
"""Creates a side-by-side Plotly SLAM animation using a history of dataclasses."""
if not history:
raise ValueError("History cannot be empty.")
print("Generating Plotly animation...")
num_steps = history[-1].step_index
fig = go.Figure()
# 1. Create static GT traces ONCE
gt_traces = []
gt_lm_trace = create_gt_landmarks_trace(landmarks_gt_array)
if gt_lm_trace:
gt_traces.append(gt_lm_trace)
gt_path_trace = create_gt_path_trace(poses_gt)
if gt_path_trace:
gt_traces.append(gt_path_trace)
# 2. Generate content for the initial frame (k=0) to set up the figure
initial_frame_data = next((item for item in history if item.step_index == 0), None)
if initial_frame_data is None:
raise ValueError("History must contain data for step 0.")
(
initial_dynamic_traces,
initial_shapes,
initial_image,
) = generate_frame_content(
initial_frame_data,
X,
L,
max_landmark_index,
ellipse_scale,
graphviz_engine,
verbose_cov_errors,
)
# 3. Add initial traces (GT + dynamic frame 0)
for trace in gt_traces:
fig.add_trace(trace)
for trace in initial_dynamic_traces:
fig.add_trace(trace)
# 4. Generate frames for the animation (k=0 to num_steps)
frames = []
steps_iterable = range(num_steps + 1)
steps_iterable = tqdm(steps_iterable, desc="Creating Frames")
for k in steps_iterable:
frame_data = next((item for item in history if item.step_index == k), None)
# Generate dynamic content specific to this frame
frame_dynamic_traces, frame_shapes, layout_image = generate_frame_content(
frame_data,
X,
L,
max_landmark_index,
ellipse_scale,
graphviz_engine,
verbose_cov_errors,
)
# Frame definition: includes static GT + dynamic traces for this step
# Layout updates only include shapes and images for this step
frames.append(
go.Frame(
data=gt_traces
+ frame_dynamic_traces, # GT must be in each frame's data
name=str(k),
layout=go.Layout(
shapes=frame_shapes, # Replaces shapes list for this frame
images=(
[layout_image] if layout_image else []
), # Replaces image list
),
)
)
# 5. Assign frames to the figure
fig.update(frames=frames)
# 6. Configure overall layout (sliders, buttons, axes, etc.)
configure_figure_layout(fig, num_steps, world_size, initial_shapes, initial_image)
print("Plotly animation generated.")
return fig