520 lines
17 KiB
Python
520 lines
17 KiB
Python
# gtsam_plotly.py
|
|
import base64
|
|
from dataclasses import dataclass
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
|
|
import graphviz
|
|
import numpy as np
|
|
import plotly.graph_objects as go
|
|
from tqdm.notebook import tqdm # Optional progress bar
|
|
|
|
import gtsam
|
|
|
|
|
|
# --- Dataclass for History ---
|
|
@dataclass
class SlamFrameData:
    """Snapshot of the estimator state needed to draw one animation frame.

    Fields are listed in constructor order; only ``graph_dot_string`` has a
    default and may be omitted.
    """

    # Step this frame represents; the animation names each frame str(step_index).
    step_index: int
    # Estimated values for the variables that are active at this step.
    results: gtsam.Values
    # Marginals for the active variables, or None when they are unavailable.
    marginals: Optional[gtsam.Marginals]
    # Graphviz DOT source of the factor graph, rendered next to the plot.
    graph_dot_string: Optional[str] = None
|
|
|
|
|
|
# --- Core Ellipse Calculation & Path Generation ---
|
|
|
|
|
|
def create_ellipse_path_from_cov(
    cx: float, cy: float, cov_matrix: np.ndarray, scale: float = 2.0, N: int = 60
) -> str:
    """Generate a closed SVG path outlining a 2-D covariance ellipse.

    The ellipse is centred at ``(cx, cy)``. Its semi-axes point along the
    eigenvectors of the covariance, and each is ``scale`` standard deviations
    long.

    Args:
        cx: x-coordinate of the ellipse centre.
        cy: y-coordinate of the ellipse centre.
        cov_matrix: Covariance matrix (array-like). Only the top-left 2x2 block
            is used, so a larger covariance (e.g. 3x3) may be passed as-is. It
            must already be expressed in the plot's coordinate frame.
        scale: Number of standard deviations spanned by each semi-axis.
        N: Number of distinct vertices on the outline.

    Returns:
        An SVG path string ``"M x0,y0 Lx1,y1 ... Z"`` usable as a Plotly
        ``type="path"`` shape. If the 2x2 block contains NaN/inf or cannot be
        eigendecomposed, a circle of radius ``0.1 * scale`` is returned instead.
    """
    # The closing "Z" already joins the last vertex to the first, so sample the
    # angle without its 2*pi endpoint to avoid emitting a duplicate vertex.
    t = np.linspace(0.0, 2.0 * np.pi, N, endpoint=False)
    cos_t, sin_t = np.cos(t), np.sin(t)

    # Tiny diagonal jitter keeps a (near-)singular covariance decomposable.
    cov = np.asarray(cov_matrix, dtype=float)[:2, :2] + np.eye(2) * 1e-9

    decomposition = None
    if np.all(np.isfinite(cov)):
        try:
            decomposition = np.linalg.eigh(cov)
        except np.linalg.LinAlgError:
            decomposition = None

    if decomposition is None:
        # Fallback circle: NaN/inf input would otherwise leak "nan" into the
        # SVG path and break the Plotly shape.
        radius = 0.1 * scale
        x_p = cx + radius * cos_t
        y_p = cy + radius * sin_t
    else:
        eigvals, eigvecs = decomposition
        # eigh returns eigenvalues in ascending order: column 0 is the minor axis.
        minor_std, major_std = np.sqrt(np.maximum(eigvals, 1e-9))
        v_minor, v_major = eigvecs[:, 0], eigvecs[:, 1]
        major_term = major_std * cos_t
        minor_term = minor_std * sin_t
        x_p = cx + scale * (major_term * v_major[0] + minor_term * v_minor[0])
        y_p = cy + scale * (major_term * v_major[1] + minor_term * v_minor[1])

    segments = " ".join(f"L{x_},{y_}" for x_, y_ in zip(x_p[1:], y_p[1:]))
    return f"M {x_p[0]},{y_p[0]} {segments} Z"
|
|
|
|
|
|
# --- Plotly Element Generators ---
|
|
|
|
|
|
def create_gt_landmarks_trace(
    landmarks_gt: Optional[np.ndarray],
) -> Optional[go.Scatter]:
    """Return a black star-marker trace of the ground-truth landmarks.

    Row 0 of ``landmarks_gt`` supplies the x-coordinates and row 1 the
    y-coordinates. ``None`` or an empty array yields ``None`` so the caller can
    skip the trace.
    """
    has_landmarks = landmarks_gt is not None and landmarks_gt.size > 0
    if not has_landmarks:
        return None
    xs, ys = landmarks_gt[0, :], landmarks_gt[1, :]
    star_marker = dict(color="black", size=8, symbol="star")
    return go.Scatter(
        name="Landmarks GT",
        mode="markers",
        marker=star_marker,
        x=xs,
        y=ys,
    )
|
|
|
|
|
|
def create_gt_path_trace(poses_gt: Optional[List[gtsam.Pose2]]) -> Optional[go.Scatter]:
    """Return a dashed grey polyline through the ground-truth poses.

    ``None`` or an empty list yields ``None``.
    """
    if not poses_gt:
        return None
    xs: List[float] = []
    ys: List[float] = []
    for pose in poses_gt:
        xs.append(pose.x())
        ys.append(pose.y())
    dashed_grey = dict(color="gray", width=1, dash="dash")
    return go.Scatter(name="Path GT", mode="lines", line=dashed_grey, x=xs, y=ys)
|
|
|
|
|
|
def create_est_path_trace(
    est_path_x: List[float], est_path_y: List[float]
) -> go.Scatter:
    """Return the estimated-path trace: a faint red line with solid red markers.

    The coordinate lists are passed through unchanged, so an empty pair yields
    an empty (but still valid) trace.
    """
    faint_red_line = dict(color="rgba(255, 0, 0, 0.3)", width=1)
    red_markers = dict(size=4, color="red")
    return go.Scatter(
        name="Path Est",
        mode="lines+markers",
        line=faint_red_line,
        marker=red_markers,
        x=est_path_x,
        y=est_path_y,
    )
|
|
|
|
|
|
def create_current_pose_trace(
    current_pose: Optional[gtsam.Pose2],
) -> Optional[go.Scatter]:
    """Return a single magenta cross at ``current_pose``, or ``None`` if there is no pose."""
    if current_pose is not None:
        magenta_cross = dict(color="magenta", size=10, symbol="cross")
        return go.Scatter(
            name="Current Pose",
            mode="markers",
            marker=magenta_cross,
            x=[current_pose.x()],
            y=[current_pose.y()],
        )
    return None
|
|
|
|
|
|
def create_est_landmarks_trace(
    est_landmarks_x: List[float], est_landmarks_y: List[float]
) -> Optional[go.Scatter]:
    """Return blue 'x' markers for the estimated landmarks.

    Returns ``None`` when ``est_landmarks_x`` is empty.
    """
    if est_landmarks_x:
        blue_x = dict(color="blue", size=6, symbol="x")
        return go.Scatter(
            name="Landmarks Est",
            mode="markers",
            marker=blue_x,
            x=est_landmarks_x,
            y=est_landmarks_y,
        )
    return None
|
|
|
|
|
|
def _create_ellipse_shape_dict(
    cx: float, cy: float, cov: np.ndarray, scale: float, fillcolor: str, line_color: str
) -> Dict[str, Any]:
    """Wrap a covariance-ellipse outline in a Plotly path-shape dict.

    The shape is placed in data coordinates (``xref="x"``, ``yref="y"``). It
    uses the given fill and outline colours plus an overall opacity of 0.7.
    """
    outline = create_ellipse_path_from_cov(cx, cy, cov, scale)
    shape: Dict[str, Any] = {
        "type": "path",
        "path": outline,
        "xref": "x",
        "yref": "y",
        "fillcolor": fillcolor,
        "line_color": line_color,
        "opacity": 0.7,  # on top of the colours' own alpha
    }
    return shape
|
|
|
|
|
|
def create_pose_ellipse_shape(
    pose_mean_xy: np.ndarray, pose_cov: np.ndarray, scale: float
) -> Dict[str, Any]:
    """Return a magenta ellipse shape for a pose covariance.

    Elements 0 and 1 of ``pose_mean_xy`` give the centre. ``pose_cov`` is
    forwarded unchanged to the ellipse builder.
    """
    magenta_fill = "rgba(255,0,255,0.2)"
    magenta_edge = "rgba(255,0,255,0.5)"
    centre_x, centre_y = pose_mean_xy[0], pose_mean_xy[1]
    return _create_ellipse_shape_dict(
        cx=centre_x,
        cy=centre_y,
        cov=pose_cov,
        scale=scale,
        fillcolor=magenta_fill,
        line_color=magenta_edge,
    )
|
|
|
|
|
|
def create_landmark_ellipse_shape(
    lm_mean_xy: np.ndarray, lm_cov: np.ndarray, scale: float
) -> Dict[str, Any]:
    """Return a light-blue ellipse shape for a landmark covariance.

    Elements 0 and 1 of ``lm_mean_xy`` give the centre. ``lm_cov`` is forwarded
    unchanged to the ellipse builder.
    """
    blue_fill = "rgba(0,0,255,0.1)"
    blue_edge = "rgba(0,0,255,0.3)"
    centre_x, centre_y = lm_mean_xy[0], lm_mean_xy[1]
    return _create_ellipse_shape_dict(
        cx=centre_x,
        cy=centre_y,
        cov=lm_cov,
        scale=scale,
        fillcolor=blue_fill,
        line_color=blue_edge,
    )
|
|
|
|
|
|
def dot_string_to_base64_svg(
    dot_string: Optional[str], engine: str = "neato"
) -> Optional[str]:
    """Render DOT source to SVG and return it as a base64 ``data:`` URI.

    Returns ``None`` for a missing or empty DOT string. Rendering failures are
    also reported as ``None``, after a message is printed.
    """
    if not dot_string:
        return None
    try:
        rendered = graphviz.Source(dot_string, engine=engine).pipe(format="svg")
        payload = base64.b64encode(rendered).decode("utf-8")
        return f"data:image/svg+xml;base64,{payload}"
    except graphviz.backend.execute.CalledProcessError as err:
        print(f"Graphviz rendering error ({engine}): {err}")
    except Exception as err:
        print(f"Unexpected error during Graphviz SVG generation: {err}")
    return None
|
|
|
|
|
|
# --- Frame Content Generation ---
|
|
def _collect_estimated_path(
    results: gtsam.Values, X: Callable[[int], int], k: int
) -> Tuple[List[float], List[float], Optional[gtsam.Pose2]]:
    """Gather x/y of poses X(0)..X(k) present in ``results``, plus pose X(k) if present."""
    path_x: List[float] = []
    path_y: List[float] = []
    current_pose: Optional[gtsam.Pose2] = None
    for i in range(k + 1):
        pose_key = X(i)
        if not results.exists(pose_key):
            continue  # e.g. no longer active at this step (within-lag estimates)
        pose = results.atPose2(pose_key)
        path_x.append(pose.x())
        path_y.append(pose.y())
        if i == k:
            current_pose = pose
    return path_x, path_y, current_pose


def _collect_estimated_landmarks(
    results: gtsam.Values, L: Callable[[int], int], max_landmark_index: int
) -> Tuple[List[float], List[float], List[int]]:
    """Gather x/y and keys of landmarks L(0)..L(max_landmark_index) present in ``results``."""
    lm_x: List[float] = []
    lm_y: List[float] = []
    lm_keys: List[int] = []
    for j in range(max_landmark_index + 1):
        lm_key = L(j)
        if results.exists(lm_key):
            lm_point = results.atPoint2(lm_key)
            lm_x.append(lm_point[0])
            lm_y.append(lm_point[1])
            lm_keys.append(lm_key)  # kept for the covariance lookup
    return lm_x, lm_y, lm_keys


def _cleared(trace: go.Scatter) -> go.Scatter:
    """Clear the x/y data of ``trace`` in place, keeping its name and styling."""
    trace.update(x=[], y=[])
    return trace


def _covariance_ellipse_shapes(
    results: gtsam.Values,
    marginals: gtsam.Marginals,
    current_pose: Optional[gtsam.Pose2],
    pose_key: int,
    landmark_keys: List[int],
    ellipse_scale: float,
    k: int,
    verbose: bool,
) -> List[Dict[str, Any]]:
    """Build ellipse shapes for the current pose and the given landmarks (best effort)."""
    shapes: List[Dict[str, Any]] = []

    if current_pose is not None:
        pose_cov = None
        try:
            pose_cov = np.asarray(marginals.marginalCovariance(pose_key))
        except Exception as e:  # covariance may be unavailable for this key
            if verbose:
                print(f"Warn: Pose {k} cov not available @ step {k}: {e}")
        if pose_cov is not None:
            # GTSAM Pose2 marginals live in the pose's body frame (tangent order
            # x, y, theta). Rotate the xy block into the world frame before
            # drawing it around the world-frame mean.
            w_R_b = current_pose.rotation().matrix()
            world_cov = w_R_b @ pose_cov[:2, :2] @ w_R_b.T
            shapes.append(
                create_pose_ellipse_shape(
                    current_pose.translation(), world_cov, ellipse_scale
                )
            )

    for lm_key in landmark_keys:
        try:
            # Point2 covariances are already in the world frame.
            cov = marginals.marginalCovariance(lm_key)
            mean = results.atPoint2(lm_key)
            shapes.append(create_landmark_ellipse_shape(mean, cov, ellipse_scale))
        except RuntimeError:  # covariance might not be available
            if verbose:
                print(
                    f"Warn: LM {gtsam.Symbol(lm_key).index()} cov not in marginals @ step {k}"
                )
        except Exception as e:
            if verbose:
                print(
                    f"Warn: LM {gtsam.Symbol(lm_key).index()} cov OTHER err @ step {k}: {e}"
                )
    return shapes


def _graph_layout_image(
    dot_string: Optional[str], engine: str
) -> Optional[Dict[str, Any]]:
    """Render the factor-graph DOT source into a Plotly layout image for the left pane."""
    img_src = dot_string_to_base64_svg(dot_string, engine=engine)
    if not img_src:
        return None
    return dict(
        source=img_src,
        xref="paper",
        yref="paper",
        x=0,
        y=1,
        sizex=0.48,
        sizey=1,
        xanchor="left",
        yanchor="top",
        layer="below",
        sizing="contain",
    )


def generate_frame_content(
    frame_data: SlamFrameData,
    X: Callable[[int], int],
    L: Callable[[int], int],
    max_landmark_index: int,
    ellipse_scale: float = 2.0,
    graphviz_engine: str = "neato",
    verbose: bool = False,
) -> Tuple[List[go.Scatter], List[Dict[str, Any]], Optional[Dict[str, Any]]]:
    """Generate the dynamic traces, shapes and layout image for one frame.

    Args:
        frame_data: Snapshot for this step (estimates, marginals, DOT graph).
        X: Maps a pose index to its key.
        L: Maps a landmark index to its key.
        max_landmark_index: Largest landmark index to look up (inclusive).
        ellipse_scale: Standard-deviation multiplier for covariance ellipses.
        graphviz_engine: Graphviz layout engine for the factor-graph image.
        verbose: Print a warning when a covariance cannot be drawn.

    Returns:
        ``(traces, shapes, image)``:

        - ``traces`` always holds exactly three scatter traces in a fixed order:
          estimated path, current pose, estimated landmarks. Missing items are
          empty, identically styled traces.
        - ``shapes`` holds the covariance-ellipse shape dicts.
        - ``image`` is the factor-graph layout image, or ``None``.
    """
    k = frame_data.step_index
    step_results = frame_data.results
    step_marginals = frame_data.marginals

    # 1. Estimated path (poses still present in this frame's results).
    est_path_x, est_path_y, current_pose_est = _collect_estimated_path(
        step_results, X, k
    )

    # 2. Estimated landmarks present in this frame's results.
    est_landmarks_x, est_landmarks_y, landmark_keys = _collect_estimated_landmarks(
        step_results, L, max_landmark_index
    )

    # Plotly applies a frame's traces to the figure's traces by position, so
    # every frame must return the same traces in the same order. Absent items
    # become empty placeholders built by the same factories (the dummy inputs
    # only supply the styling; their data is cleared immediately).
    current_pose_trace = create_current_pose_trace(current_pose_est)
    if current_pose_trace is None:
        current_pose_trace = _cleared(create_current_pose_trace(gtsam.Pose2()))
    lm_trace = create_est_landmarks_trace(est_landmarks_x, est_landmarks_y)
    if lm_trace is None:
        lm_trace = _cleared(create_est_landmarks_trace([0.0], [0.0]))

    frame_dynamic_traces: List[go.Scatter] = [
        create_est_path_trace(est_path_x, est_path_y),
        current_pose_trace,
        lm_trace,
    ]

    # 3. Covariance ellipses (only for items in step_results AND step_marginals).
    frame_shapes: List[Dict[str, Any]] = []
    if step_marginals is not None:
        frame_shapes = _covariance_ellipse_shapes(
            step_results,
            step_marginals,
            current_pose_est,
            X(k),
            landmark_keys,
            ellipse_scale,
            k,
            verbose,
        )

    # 4. Factor-graph image for the left pane.
    layout_image = _graph_layout_image(frame_data.graph_dot_string, graphviz_engine)

    return frame_dynamic_traces, frame_shapes, layout_image
|
|
|
|
|
|
# --- Figure Configuration ---
|
|
|
|
|
|
def configure_figure_layout(
    fig: go.Figure,
    num_steps: int,
    world_size: float,
    initial_shapes: List[Dict[str, Any]],
    initial_image: Optional[Dict[str, Any]],
) -> None:
    """Lay out ``fig`` with the graph image on the left and the SLAM plot on the right.

    The following are installed, and ``fig`` is modified in place:

    - a "Step" slider with one entry per frame name ``"0"`` .. ``str(num_steps)``;
    - Play/Pause buttons;
    - equal-aspect axes spanning ``world_size`` plus a margin of 2 on each side;
    - the frame-0 shapes and image.
    """
    # Horizontal paper-coordinate span of the SLAM axes; the space to the left
    # is reserved for the factor-graph image.
    pane_left, pane_right = 0.52, 1.0
    half_span = world_size / 2 + 2
    axis_range = [-half_span, half_span]

    def anim_opts(duration: int, redraw: bool, **extra: Any) -> Dict[str, Any]:
        """Animation options shared by the slider steps and the buttons."""
        opts: Dict[str, Any] = dict(
            mode="immediate",
            frame=dict(duration=duration, redraw=redraw),
            transition=dict(duration=0),
        )
        opts.update(extra)
        return opts

    step_entries = []
    for step in range(num_steps + 1):
        frame_name = str(step)
        step_entries.append(
            dict(
                label=frame_name,
                method="animate",
                args=[[frame_name], anim_opts(100, True)],
            )
        )
    step_slider = dict(
        active=0,
        currentvalue={"prefix": "Step: "},
        pad={"t": 50},
        steps=step_entries,
    )

    play_button = dict(
        label="Play",
        method="animate",
        args=[None, anim_opts(100, True, fromcurrent=True)],
    )
    pause_button = dict(
        label="Pause",
        method="animate",
        args=[[None], anim_opts(0, False)],
    )
    playback_menu = dict(
        type="buttons",
        showactive=False,
        direction="left",
        pad={"r": 10, "t": 87},
        x=pane_left,
        xanchor="left",
        y=0,
        yanchor="top",
        buttons=[play_button, pause_button],
    )

    fig.update_layout(
        title="Factor Graph SLAM Animation (Graph Left, Results Right)",
        xaxis=dict(
            range=list(axis_range),
            domain=[pane_left, pane_right],
            constrain="domain",
        ),
        yaxis=dict(
            range=list(axis_range),
            scaleanchor="x",
            scaleratio=1,
            domain=[0, 1],
        ),
        width=1000,
        height=600,
        hovermode="closest",
        updatemenus=[playback_menu],
        sliders=[step_slider],
        shapes=initial_shapes,
        images=[initial_image] if initial_image else [],
        showlegend=True,
        legend=dict(
            x=pane_left,
            y=1,
            traceorder="normal",
            bgcolor="rgba(255,255,255,0.5)",
        ),
    )
|
|
|
|
|
|
# --- Main Animation Orchestrator ---
|
|
|
|
|
|
def create_slam_animation(
    history: List[SlamFrameData],
    X: Callable[[int], int],
    L: Callable[[int], int],
    max_landmark_index: int,
    landmarks_gt_array: Optional[np.ndarray] = None,
    poses_gt: Optional[List[gtsam.Pose2]] = None,
    world_size: float = 20.0,
    ellipse_scale: float = 2.0,
    graphviz_engine: str = "neato",
    verbose_cov_errors: bool = False,
) -> go.Figure:
    """Create a side-by-side Plotly SLAM animation from a history of frame snapshots.

    Args:
        history: Per-step snapshots. Must contain an entry for step 0. If a
            step occurs more than once, its first entry is used.
        X: Maps a pose index to its key.
        L: Maps a landmark index to its key.
        max_landmark_index: Largest landmark index to look up (inclusive).
        landmarks_gt_array: Optional ground-truth landmarks (row 0 = x, row 1 = y).
        poses_gt: Optional ground-truth poses drawn as a dashed path.
        world_size: Extent of the square plotting area (plus a margin).
        ellipse_scale: Standard-deviation multiplier for covariance ellipses.
        graphviz_engine: Graphviz layout engine for the factor-graph images.
        verbose_cov_errors: Print warnings when a covariance cannot be drawn.

    Returns:
        A figure with one animation frame per step ``0..num_steps``, where
        ``num_steps`` is the largest ``step_index`` in ``history``. A step
        missing from ``history`` repeats the previous step's content.

    Raises:
        ValueError: If ``history`` is empty or has no entry for step 0.
    """
    if not history:
        raise ValueError("History cannot be empty.")
    print("Generating Plotly animation...")

    # Index the history once (first entry per step wins, matching the previous
    # linear search) instead of rescanning the whole list for every frame.
    data_by_step: Dict[int, SlamFrameData] = {}
    for item in history:
        data_by_step.setdefault(item.step_index, item)
    if 0 not in data_by_step:
        raise ValueError("History must contain data for step 0.")
    # Largest step rather than the last entry, so an unsorted history still
    # gets a frame for every step.
    num_steps = max(data_by_step)

    # 1. Static ground-truth traces, shared by the figure and every frame.
    gt_traces = [
        trace
        for trace in (
            create_gt_landmarks_trace(landmarks_gt_array),
            create_gt_path_trace(poses_gt),
        )
        if trace is not None
    ]

    # 2. Generate one frame per step; step 0's content also seeds the figure.
    frames: List[go.Frame] = []
    content = None
    initial_content = None
    for k in tqdm(range(num_steps + 1), desc="Creating Frames"):
        frame_data = data_by_step.get(k)
        if frame_data is not None:
            content = generate_frame_content(
                frame_data,
                X,
                L,
                max_landmark_index,
                ellipse_scale,
                graphviz_engine,
                verbose_cov_errors,
            )
        # A step absent from history keeps the previous step's content, so the
        # slider (one entry per step) never targets a missing frame. Step 0 is
        # guaranteed present, so ``content`` is always set here.
        dynamic_traces, shapes, image = content
        if initial_content is None:
            initial_content = content
        frames.append(
            go.Frame(
                data=gt_traces + list(dynamic_traces),  # GT must be in each frame
                name=str(k),
                layout=go.Layout(
                    shapes=shapes,  # replaces the shapes list for this frame
                    images=[image] if image else [],  # replaces the image list
                ),
            )
        )

    # 3. Figure: GT traces followed by the step-0 dynamic traces.
    initial_traces, initial_shapes, initial_image = initial_content
    fig = go.Figure(data=gt_traces + list(initial_traces))
    fig.update(frames=frames)

    # 4. Layout (axes, slider, buttons, initial shapes/image).
    configure_figure_layout(fig, num_steps, world_size, initial_shapes, initial_image)

    print("Plotly animation generated.")
    return fig
|