Plot side by side

release/4.3a0
Frank Dellaert 2025-04-24 16:42:23 -04:00
parent 1bf76be62b
commit e4278687b4
2 changed files with 221 additions and 199 deletions

View File

@ -106,8 +106,6 @@
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt # For initial plot if desired (optional)\n",
"import plotly.graph_objects as go\n",
"from tqdm.notebook import tqdm # Progress bar\n",
"import math\n",
"import time # To slow down graphviz display if needed\n",
@ -117,11 +115,10 @@
"\n",
"# Imports for new modules\n",
"import simulation\n",
"import gtsam_plotly \n",
"from gtsam_plotly import SlamFrameData, create_slam_animation\n",
"\n",
"# Imports for graph visualization\n",
"import graphviz\n",
"from IPython.display import display, clear_output"
"import graphviz"
]
},
{
@ -158,11 +155,7 @@
"MEASUREMENT_NOISE = gtsam.noiseModel.Diagonal.Sigmas(np.array([np.deg2rad(2.0), 0.2]))\n",
"\n",
"# Sensor parameters\n",
"MAX_SENSOR_RANGE = 5.0\n",
"\n",
"# Visualization parameters\n",
"SHOW_GRAPHVIZ_EACH_STEP = True # Set to False to only show final graph\n",
"GRAPHVIZ_PAUSE_SECONDS = 1 # Pause briefly to allow viewing the graph"
"MAX_SENSOR_RANGE = 5.0"
]
},
{
@ -215,15 +208,11 @@
"metadata": {},
"outputs": [],
"source": [
"writer = gtsam.GraphvizFormatting()\n",
"writer.binaryEdges = True\n",
"WRITER = gtsam.GraphvizFormatting()\n",
"WRITER.binaryEdges = True\n",
"\n",
"def maybe_show(graph, estimate, message):\n",
" if SHOW_GRAPHVIZ_EACH_STEP:\n",
" print(message)\n",
" graph_dot = graph.dot(estimate, writer=writer)\n",
" display(graphviz.Source(graph_dot, engine='neato'))\n",
" time.sleep(GRAPHVIZ_PAUSE_SECONDS)"
"def mage_dot(graph, estimate):\n",
" return graph.dot(estimate, writer=WRITER)\n"
]
},
{
@ -248,20 +237,13 @@
"current_graph.add(gtsam.PriorFactorPose2(current_pose_key, initial_pose, PRIOR_NOISE))\n",
"\n",
"# Variables to store results for animation\n",
"results_history = [gtsam.Values(current_estimate)] # Store Values object at each step\n",
"marginals_history = [] # Store Marginals object at each step\n",
"history = [] # Store SLAMFrameData objects for each step\n",
"known_landmarks = set() # Set of landmark keys L(j) currently in the state\n",
"\n",
"# Initial marginals (only for X(0))\n",
"try:\n",
" initial_gaussian_graph = current_graph.linearize(current_estimate)\n",
" initial_marginals = gtsam.Marginals(initial_gaussian_graph, current_estimate)\n",
" marginals_history.append(initial_marginals)\n",
"except Exception as e:\n",
" print(f\"Error computing initial marginals: {e}\")\n",
" marginals_history.append(None) # Append placeholder if fails\n",
"\n",
"maybe_show(current_graph, current_estimate, \"--- Step 0 --- Initial Graph with Prior ---\")"
"initial_gaussian_graph = current_graph.linearize(current_estimate)\n",
"initial_marginals = gtsam.Marginals(initial_gaussian_graph, current_estimate)\n",
"history.append(SlamFrameData(0, current_estimate, initial_marginals, mage_dot(current_graph, current_estimate)))"
]
},
{
@ -324,16 +306,12 @@
" optimizer = gtsam.LevenbergMarquardtOptimizer(current_graph, current_estimate)\n",
" current_estimate = optimizer.optimize() # Update estimates by optimizing the whole graph so far\n",
" \n",
" # --- Display Factor Graph --- \n",
" maybe_show(current_graph, current_estimate, f\"--- Step {k+1} --- Factor Graph ---\")\n",
"\n",
" # Store results for animation\n",
" results_history.append(gtsam.Values(current_estimate)) # Store a copy\n",
" \n",
" # Calculate marginals for visualization (can be slow for large graphs)\n",
" current_gaussian_graph = current_graph.linearize(current_estimate)\n",
" current_marginals = gtsam.Marginals(current_gaussian_graph, current_estimate)\n",
" marginals_history.append(current_marginals)\n",
"\n",
" # Store the current state for visualization\n",
" history.append(SlamFrameData(k+1, current_estimate, current_marginals, mage_dot(current_graph, current_estimate)))\n",
"\n",
"print(\"\\nIterative Factor Graph SLAM finished.\")\n",
"print(f\"Final number of poses estimated: {len([key for key in current_estimate.keys() if gtsam.Symbol(key).chr() == ord('x')])}\")\n",
@ -360,15 +338,14 @@
},
"outputs": [],
"source": [
"fig = gtsam_plotly.create_slam_animation(\n",
" results_history=results_history,\n",
" marginals_history=marginals_history,\n",
"fig = create_slam_animation(\n",
" history,\n",
" landmarks_gt_array=landmarks_gt_array,\n",
" poses_gt=poses_gt,\n",
" num_steps=NUM_STEPS,\n",
" world_size=WORLD_SIZE,\n",
" X=X, # Pass symbol functions\n",
" L=L\n",
" X=X, # Pass symbol functions\n",
" L=L,\n",
" max_landmark_index=NUM_LANDMARKS,\n",
")\n",
"\n",
"print(\"Displaying animation...\")\n",
@ -384,7 +361,7 @@
"\n",
"* **Approach:** This notebook implemented SLAM iteratively using GTSAM factor graphs. At each step, new factors (odometry, measurements) were added, and the graph was re-optimized using Levenberg-Marquardt. This is more akin to **incremental smoothing** than a classic filter (which would explicitly marginalize past states).\n",
"* **Modularity:** Simulation and Plotly visualization code have been moved into separate `simulation.py` and `gtsam_plotly.py` files for better organization.\n",
"* **Graph Visualization:** The `graphviz` library was used to render the factor graph at each step (or only at the end, depending on `SHOW_GRAPHVIZ_EACH_STEP`). This helps visualize how the graph structure grows and connects poses and landmarks.\n",
"* **Graph Visualization:** The `graphviz` library was used to render the factor graph at each step. This helps visualize how the graph structure grows and connects poses and landmarks.\n",
"* **Efficiency:** Optimizing the entire graph at every step becomes computationally expensive for long trajectories. For real-time performance or large problems, **iSAM2 (Incremental Smoothing and Mapping)** is the preferred GTSAM algorithm. iSAM2 efficiently updates the solution by only re-linearizing and re-solving parts of the graph affected by new measurements.\n",
"* **Accuracy vs. EKF:** This factor graph approach generally handles non-linearities better than a standard EKF because it re-linearizes during optimization. It avoids some of the consistency pitfalls associated with the EKF's single linearization point per step.\n",
"* **Visualization:** The Plotly animation shows the evolution of the robot's path estimate, the map of landmarks, and their associated uncertainties (covariance ellipses). You can observe how measurements help refine the estimates and reduce uncertainty, especially when loops are closed (implicitly here through repeated observations of landmarks)."

View File

@ -1,63 +1,78 @@
# gtsam_plotly_modular_v2.py
# gtsam_plotly_modular_v3.py
import base64
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional, Tuple
import graphviz
import numpy as np
import plotly.graph_objects as go
from tqdm.notebook import tqdm
import gtsam
# --- Core Ellipse Calculations ---
# --- Dataclass for History ---
@dataclass
class SlamFrameData:
"""Holds all data needed for a single frame of the SLAM animation."""
step_index: int
results: gtsam.Values
marginals: Optional[gtsam.Marginals]
graph_dot_string: Optional[str] = None # Store the Graphviz DOT string
def ellipse_path(
cx: float, cy: float, sizex: float, sizey: float, angle: float, N: int = 60
# --- Core Ellipse Calculation & Path Generation ---
def create_ellipse_path_from_cov(
cx: float, cy: float, cov_matrix: np.ndarray, scale: float = 2.0, N: int = 60
) -> str:
"""Generates SVG path string for a rotated ellipse."""
angle_rad = np.radians(angle)
t = np.linspace(0, 2 * np.pi, N)
x_unit = (sizex / 2) * np.cos(t)
y_unit = (sizey / 2) * np.sin(t)
x_rot = cx + x_unit * np.cos(angle_rad) - y_unit * np.sin(angle_rad)
y_rot = cy + x_unit * np.sin(angle_rad) + y_unit * np.cos(angle_rad)
"""Generates SVG path string for an ellipse directly from 2x2 covariance."""
cov = cov_matrix[:2, :2] + np.eye(2) * 1e-9 # Ensure positive definite
try:
eigvals, eigvecs = np.linalg.eigh(cov)
# eigh sorts eigenvalues in ascending order
eigvals = np.maximum(eigvals, 1e-9) # Ensure positive eigenvalues
minor_std, major_std = np.sqrt(eigvals)
v_minor, v_major = eigvecs[:, 0], eigvecs[:, 1] # Corresponding eigenvectors
except np.linalg.LinAlgError:
# Fallback to a small circle if decomposition fails
radius = 0.1 * scale
t = np.linspace(0, 2 * np.pi, N)
x_p = cx + radius * np.cos(t)
y_p = cy + radius * np.sin(t)
else:
# Parametric equation: Center + scale * major_std * v_major * cos(t) + scale * minor_std * v_minor * sin(t)
# Note: We use scale*std which corresponds to half-axis lengths a,b used before
t = np.linspace(0, 2 * np.pi, N)
cos_t, sin_t = np.cos(t), np.sin(t)
x_p = cx + scale * (
major_std * cos_t * v_major[0] + minor_std * sin_t * v_minor[0]
)
y_p = cy + scale * (
major_std * cos_t * v_major[1] + minor_std * sin_t * v_minor[1]
)
# Build SVG path string
path = (
f"M {x_rot[0]},{y_rot[0]} "
+ " ".join(f"L{x_},{y_}" for x_, y_ in zip(x_rot[1:], y_rot[1:]))
f"M {x_p[0]},{y_p[0]} "
+ " ".join(f"L{x_},{y_}" for x_, y_ in zip(x_p[1:], y_p[1:]))
+ " Z"
)
return path
def gtsam_cov_to_plotly_ellipse(
cov_matrix: np.ndarray, scale: float = 2.0
) -> Tuple[float, float, float]:
"""Calculates ellipse angle (deg), width, height from 2x2 covariance."""
cov = cov_matrix[:2, :2] + np.eye(2) * 1e-9 # Ensure positive definite
try:
eigvals, eigvecs = np.linalg.eigh(cov)
eigvals = np.maximum(eigvals, 1e-9) # Ensure positive eigenvalues
except np.linalg.LinAlgError:
return 0, 0.1 * scale, 0.1 * scale # Default on failure
width = 2 * scale * np.sqrt(eigvals[1]) # Major axis (largest eigenvalue)
height = 2 * scale * np.sqrt(eigvals[0]) # Minor axis (smallest eigenvalue)
angle_rad = np.arctan2(
eigvecs[1, 1], eigvecs[0, 1]
) # Angle of major axis eigenvector
angle_deg = np.degrees(angle_rad)
return angle_deg, width, height
# --- Plotly Element Generators ---
def create_gt_landmarks_trace(landmarks_gt_array: np.ndarray) -> Optional[go.Scatter]:
def create_gt_landmarks_trace(landmarks_gt: np.ndarray) -> Optional[go.Scatter]:
"""Creates scatter trace for ground truth landmarks."""
if landmarks_gt_array is None or landmarks_gt_array.size == 0:
if landmarks_gt is None or landmarks_gt.size == 0:
return None
return go.Scatter(
x=landmarks_gt_array[0, :],
y=landmarks_gt_array[1, :],
x=landmarks_gt[0, :],
y=landmarks_gt[1, :],
mode="markers",
marker=dict(color="black", size=8, symbol="star"),
name="Landmarks GT",
@ -68,11 +83,9 @@ def create_gt_path_trace(poses_gt: List[gtsam.Pose2]) -> Optional[go.Scatter]:
"""Creates line trace for ground truth path."""
if not poses_gt:
return None
gt_path_x = [p.x() for p in poses_gt]
gt_path_y = [p.y() for p in poses_gt]
return go.Scatter(
x=gt_path_x,
y=gt_path_y,
x=[p.x() for p in poses_gt],
y=[p.y() for p in poses_gt],
mode="lines",
line=dict(color="gray", width=1, dash="dash"),
name="Path GT",
@ -82,21 +95,21 @@ def create_gt_path_trace(poses_gt: List[gtsam.Pose2]) -> Optional[go.Scatter]:
def create_est_path_trace(
est_path_x: List[float], est_path_y: List[float]
) -> go.Scatter:
"""Creates scatter/line trace for the estimated path up to current step."""
"""Creates trace for the estimated path."""
return go.Scatter(
x=est_path_x,
y=est_path_y,
mode="lines+markers",
line=dict(color="red", width=2),
marker=dict(size=4, color="red"),
name="Path Est", # This name applies to the trace in the specific frame
name="Path Est",
)
def create_est_landmarks_trace(
est_landmarks_x: List[float], est_landmarks_y: List[float]
) -> Optional[go.Scatter]:
"""Creates scatter trace for currently estimated landmarks."""
"""Creates trace for estimated landmarks."""
if not est_landmarks_x:
return None
return go.Scatter(
@ -104,136 +117,155 @@ def create_est_landmarks_trace(
y=est_landmarks_y,
mode="markers",
marker=dict(color="blue", size=6, symbol="x"),
name="Landmarks Est", # Applies to landmarks in the specific frame
name="Landmarks Est",
)
def _create_ellipse_shape_dict(
cx, cy, angle, width, height, fillcolor, line_color, name_suffix
cx, cy, cov, scale, fillcolor, line_color
) -> Dict[str, Any]:
"""Helper to create the dictionary for a Plotly ellipse shape."""
"""Helper: Creates dict for a Plotly ellipse shape from covariance."""
path = create_ellipse_path_from_cov(cx, cy, cov, scale)
return dict(
type="path",
path=ellipse_path(cx=cx, cy=cy, sizex=width, sizey=height, angle=angle),
path=path,
xref="x",
yref="y",
fillcolor=fillcolor,
line_color=line_color,
# name=f"{name_suffix} Cov", # Name isn't really used by Plotly for shapes
)
def create_pose_ellipse_shape(
pose_mean_xy: np.ndarray, pose_cov: np.ndarray, k: int, scale: float
pose_mean_xy: np.ndarray, pose_cov: np.ndarray, scale: float
) -> Dict[str, Any]:
"""Creates shape dictionary for a pose covariance ellipse."""
angle, width, height = gtsam_cov_to_plotly_ellipse(pose_cov, scale)
"""Creates shape dict for a pose covariance ellipse."""
return _create_ellipse_shape_dict(
cx=pose_mean_xy[0],
cy=pose_mean_xy[1],
angle=angle,
width=width,
height=height,
cov=pose_cov,
scale=scale,
fillcolor="rgba(255,0,255,0.2)",
line_color="rgba(255,0,255,0.5)",
name_suffix=f"Pose {k}",
)
def create_landmark_ellipse_shape(
lm_mean_xy: np.ndarray, lm_cov: np.ndarray, lm_index: int, scale: float
lm_mean_xy: np.ndarray, lm_cov: np.ndarray, scale: float
) -> Dict[str, Any]:
"""Creates shape dictionary for a landmark covariance ellipse."""
angle, width, height = gtsam_cov_to_plotly_ellipse(lm_cov, scale)
"""Creates shape dict for a landmark covariance ellipse."""
return _create_ellipse_shape_dict(
cx=lm_mean_xy[0],
cy=lm_mean_xy[1],
angle=angle,
width=width,
height=height,
cov=lm_cov,
scale=scale,
fillcolor="rgba(0,0,255,0.1)",
line_color="rgba(0,0,255,0.3)",
name_suffix=f"LM {lm_index}",
)
def dot_string_to_base64_png(
dot_string: Optional[str], engine: str = "neato"
) -> Optional[str]:
"""Renders a DOT string to a base64 encoded PNG using graphviz."""
if not dot_string:
return None
source = graphviz.Source(dot_string, engine="neato")
# Use pipe() to get bytes directly without saving a file
png_bytes = source.pipe(format="png")
encoded_string = base64.b64encode(png_bytes).decode()
return f"data:image/png;base64,{encoded_string}"
# --- Frame Content Generation ---
def generate_frame_content(
k: int,
step_results: gtsam.Values,
step_marginals: Optional[gtsam.Marginals],
frame_data: SlamFrameData,
X: Callable[[int], int],
L: Callable[[int], int],
max_landmark_index: int, # Need to know the potential range of landmarks
max_landmark_index: int,
ellipse_scale: float = 2.0,
graphviz_engine: str = "neato", # Added engine parameter
verbose: bool = False,
) -> Tuple[List[go.Scatter], List[Dict[str, Any]]]:
"""Generates all dynamic traces and shapes for a single animation frame `k`."""
) -> Tuple[List[go.Scatter], List[Dict[str, Any]], Optional[Dict[str, Any]]]:
"""Generates dynamic traces, shapes, and layout image for a single frame."""
k = frame_data.step_index
step_results = frame_data.results
step_marginals = frame_data.marginals
frame_traces: List[go.Scatter] = []
frame_shapes: List[Dict[str, Any]] = []
layout_image: Optional[Dict[str, Any]] = None
# 1. Gather Estimated Path Data
est_path_x = []
est_path_y = []
for i in range(k + 1):
pose_key = X(i)
if step_results.exists(pose_key):
pose = step_results.atPose2(pose_key)
est_path_x.append(pose.x())
est_path_y.append(pose.y())
# 1. Estimated Path (Unchanged)
est_path_x = [
step_results.atPose2(X(i)).x()
for i in range(k + 1)
if step_results.exists(X(i))
]
est_path_y = [
step_results.atPose2(X(i)).y()
for i in range(k + 1)
if step_results.exists(X(i))
]
frame_traces.append(create_est_path_trace(est_path_x, est_path_y))
# 2. Gather Estimated Landmark Data
est_landmarks_x = []
est_landmarks_y = []
landmark_keys_in_frame = []
# Check all potential landmark keys up to max_landmark_index
# 2. Estimated Landmarks (Unchanged)
est_landmarks_x, est_landmarks_y, landmark_keys = [], [], []
for j in range(max_landmark_index + 1):
lm_key = L(j)
if step_results.exists(lm_key):
lm_point = step_results.atPoint2(lm_key)
est_landmarks_x.append(lm_point[0])
est_landmarks_y.append(lm_point[1])
landmark_keys_in_frame.append(lm_key)
landmark_keys.append(lm_key)
lm_trace = create_est_landmarks_trace(est_landmarks_x, est_landmarks_y)
if lm_trace:
frame_traces.append(lm_trace)
# 3. Generate Covariance Ellipses (if marginals available)
if step_marginals is not None:
# Pose ellipse
current_pose_key = X(k)
if step_results.exists(current_pose_key):
# 3. Covariance Ellipses (Unchanged)
if step_marginals:
pose_key = X(k)
if step_results.exists(pose_key):
try:
pose_cov = step_marginals.marginalCovariance(current_pose_key)
pose_mean = step_results.atPose2(current_pose_key).translation()
frame_shapes.append(
create_pose_ellipse_shape(pose_mean, pose_cov, k, ellipse_scale)
)
cov = step_marginals.marginalCovariance(pose_key)
mean = step_results.atPose2(pose_key).translation()
frame_shapes.append(create_pose_ellipse_shape(mean, cov, ellipse_scale))
except Exception as e:
if verbose:
print(f"Warn: Pose {k} cov err @ step {k}: {e}")
# Landmark ellipses
for lm_key in landmark_keys_in_frame:
for lm_key in landmark_keys:
try:
lm_cov = step_marginals.marginalCovariance(lm_key)
lm_mean = step_results.atPoint2(lm_key)
lm_index = gtsam.Symbol(lm_key).index()
cov = step_marginals.marginalCovariance(lm_key)
mean = step_results.atPoint2(lm_key)
frame_shapes.append(
create_landmark_ellipse_shape(
lm_mean, lm_cov, lm_index, ellipse_scale
)
create_landmark_ellipse_shape(mean, cov, ellipse_scale)
)
except Exception as e:
lm_index = gtsam.Symbol(lm_key).index()
if verbose:
print(f"Warn: LM {lm_index} cov err @ step {k}: {e}")
print(
f"Warn: LM {gtsam.Symbol(lm_key).index()} cov err @ step {k}: {e}"
)
return frame_traces, frame_shapes
# 4. Graph Image for Layout (MODIFIED)
# Use the new function with the dot string from frame_data
img_src = dot_string_to_base64_png(
frame_data.graph_dot_string, engine=graphviz_engine
)
if img_src:
layout_image = dict(
source=img_src,
xref="paper",
yref="paper",
x=0,
y=1,
sizex=0.48,
sizey=1,
xanchor="left",
yanchor="top",
layer="below",
sizing="contain",
)
return frame_traces, frame_shapes, layout_image
# --- Figure Configuration ---
@ -244,9 +276,12 @@ def configure_figure_layout(
num_steps: int,
world_size: float,
initial_shapes: List[Dict[str, Any]],
initial_image: Optional[Dict[str, Any]],
) -> None:
"""Configures Plotly figure layout, axes, slider, buttons."""
"""Configures Plotly figure layout for side-by-side display."""
steps = list(range(num_steps + 1))
plot_domain = [0.52, 1.0] # Right pane for the SLAM plot
sliders = [
dict(
active=0,
@ -275,10 +310,10 @@ def configure_figure_layout(
showactive=False,
direction="left",
pad={"r": 10, "t": 87},
x=0.1,
xanchor="right",
x=plot_domain[0],
xanchor="left",
y=0,
yanchor="top",
yanchor="top", # Position relative to plot area
buttons=[
dict(
label="Play",
@ -310,28 +345,29 @@ def configure_figure_layout(
]
fig.update_layout(
title="Iterative Factor Graph SLAM Animation",
xaxis=dict(range=[-world_size / 2 - 2, world_size / 2 + 2], constrain="domain"),
title="Factor Graph SLAM Animation (Graph Left, Results Right)",
xaxis=dict(
range=[-world_size / 2 - 2, world_size / 2 + 2],
domain=plot_domain, # Confine axis to right pane
constrain="domain",
),
yaxis=dict(
range=[-world_size / 2 - 2, world_size / 2 + 2],
scaleanchor="x",
scaleratio=1,
domain=[0, 1], # Full height within its domain
),
width=800,
height=800,
width=1000,
height=600, # Adjust size for side-by-side
hovermode="closest",
updatemenus=updatemenus,
sliders=sliders,
shapes=initial_shapes,
legend=dict(
traceorder="reversed",
title_text="Legend",
orientation="h",
yanchor="bottom",
y=1.02,
xanchor="right",
x=1,
),
shapes=initial_shapes, # Initial shapes for the SLAM plot
images=(
[initial_image] if initial_image else []
), # Initial image for the left pane
showlegend=False, # Forego legend for space
# Autosize=True might help responsiveness but can interfere with domain
)
@ -339,23 +375,25 @@ def configure_figure_layout(
def create_slam_animation(
results_history: List[gtsam.Values],
marginals_history: List[Optional[gtsam.Marginals]],
num_steps: int,
history: List[SlamFrameData],
X: Callable[[int], int],
L: Callable[[int], int],
max_landmark_index: int, # Required to iterate potential landmarks
max_landmark_index: int,
landmarks_gt_array: Optional[np.ndarray] = None,
poses_gt: Optional[List[gtsam.Pose2]] = None,
world_size: float = 20.0,
ellipse_scale: float = 2.0,
graphviz_engine: str = "neato", # Pass engine choice
verbose_cov_errors: bool = False,
) -> go.Figure:
"""Creates a modular Plotly SLAM animation."""
"""Creates a side-by-side Plotly SLAM animation using a history of dataclasses."""
if not history:
raise ValueError("History cannot be empty.")
print("Generating Plotly animation...")
num_steps = history[-1].step_index
fig = go.Figure()
# 1. Add static ground truth traces to the base figure (visible always)
# 1. Add static GT traces (Unchanged)
gt_lm_trace = create_gt_landmarks_trace(landmarks_gt_array)
if gt_lm_trace:
fig.add_trace(gt_lm_trace)
@ -363,51 +401,58 @@ def create_slam_animation(
if gt_path_trace:
fig.add_trace(gt_path_trace)
# 2. Generate frames with dynamic content
# 2. Generate frames (MODIFIED: pass graphviz_engine)
frames = []
initial_dynamic_traces, initial_shapes, initial_image = [], [], None
steps_iterable = range(num_steps + 1)
try:
steps_iterable = tqdm(steps_iterable, desc="Creating Frames")
except NameError:
pass # tqdm optional
pass
for k in steps_iterable:
step_results = results_history[k]
step_marginals = marginals_history[k] if marginals_history else None
frame_data = next((item for item in history if item.step_index == k), None)
if frame_data is None:
print(f"Warning: Missing data for step {k} in history.")
continue
frame_traces, frame_shapes = generate_frame_content(
k,
step_results,
step_marginals,
# Pass the engine choice to the content generator
frame_traces, frame_shapes, layout_image = generate_frame_content(
frame_data,
X,
L,
max_landmark_index,
ellipse_scale,
graphviz_engine,
verbose_cov_errors,
)
if k == 0:
initial_dynamic_traces, initial_shapes, initial_image = (
frame_traces,
frame_shapes,
layout_image,
)
frames.append(
go.Frame(
data=frame_traces, name=str(k), layout=go.Layout(shapes=frame_shapes)
data=frame_traces,
name=str(k),
layout=go.Layout(
shapes=frame_shapes, images=[layout_image] if layout_image else []
),
)
)
# 3. Set initial dynamic data (from frame 0) onto the base figure
initial_dynamic_traces = []
initial_shapes = []
if frames:
# Important: Add *copies* or ensure traces are regenerated if needed,
# though Plotly usually handles this ok with frame data.
initial_dynamic_traces = frames[0].data
initial_shapes = frames[0].layout.shapes if frames[0].layout else []
for trace in initial_dynamic_traces:
fig.add_trace(trace) # Add Est Path[0], Est Landmarks[0] traces
# 3. Add initial dynamic traces (Unchanged)
for trace in initial_dynamic_traces:
fig.add_trace(trace)
# 4. Assign frames to the figure
# 4. Assign frames (Unchanged)
fig.update(frames=frames)
# 5. Configure layout, axes, controls
# Pass initial_shapes for the layout's starting state
configure_figure_layout(fig, num_steps, world_size, initial_shapes)
# 5. Configure layout (Unchanged)
configure_figure_layout(fig, num_steps, world_size, initial_shapes, initial_image)
print("Plotly animation generated.")
return fig