From e4278687b412c60e78bfc50c0b2c396b38553920 Mon Sep 17 00:00:00 2001 From: Frank Dellaert Date: Thu, 24 Apr 2025 16:42:23 -0400 Subject: [PATCH] Plot side by side --- python/gtsam/examples/EKF_SLAM.ipynb | 63 ++--- python/gtsam/examples/gtsam_plotly.py | 357 +++++++++++++++----------- 2 files changed, 221 insertions(+), 199 deletions(-) diff --git a/python/gtsam/examples/EKF_SLAM.ipynb b/python/gtsam/examples/EKF_SLAM.ipynb index 84e69e1a1..265005710 100644 --- a/python/gtsam/examples/EKF_SLAM.ipynb +++ b/python/gtsam/examples/EKF_SLAM.ipynb @@ -106,8 +106,6 @@ "outputs": [], "source": [ "import numpy as np\n", - "import matplotlib.pyplot as plt # For initial plot if desired (optional)\n", - "import plotly.graph_objects as go\n", "from tqdm.notebook import tqdm # Progress bar\n", "import math\n", "import time # To slow down graphviz display if needed\n", @@ -117,11 +115,10 @@ "\n", "# Imports for new modules\n", "import simulation\n", - "import gtsam_plotly \n", + "from gtsam_plotly import SlamFrameData, create_slam_animation\n", "\n", "# Imports for graph visualization\n", - "import graphviz\n", - "from IPython.display import display, clear_output" + "import graphviz" ] }, { @@ -158,11 +155,7 @@ "MEASUREMENT_NOISE = gtsam.noiseModel.Diagonal.Sigmas(np.array([np.deg2rad(2.0), 0.2]))\n", "\n", "# Sensor parameters\n", - "MAX_SENSOR_RANGE = 5.0\n", - "\n", - "# Visualization parameters\n", - "SHOW_GRAPHVIZ_EACH_STEP = True # Set to False to only show final graph\n", - "GRAPHVIZ_PAUSE_SECONDS = 1 # Pause briefly to allow viewing the graph" + "MAX_SENSOR_RANGE = 5.0" ] }, { @@ -215,15 +208,11 @@ "metadata": {}, "outputs": [], "source": [ - "writer = gtsam.GraphvizFormatting()\n", - "writer.binaryEdges = True\n", + "WRITER = gtsam.GraphvizFormatting()\n", + "WRITER.binaryEdges = True\n", "\n", - "def maybe_show(graph, estimate, message):\n", - " if SHOW_GRAPHVIZ_EACH_STEP:\n", - " print(message)\n", - " graph_dot = graph.dot(estimate, writer=writer)\n", - " display(graphviz.Source(graph_dot, engine='neato'))\n", - " time.sleep(GRAPHVIZ_PAUSE_SECONDS)" + "def mage_dot(graph, estimate):\n", + " return graph.dot(estimate, writer=WRITER)\n" ] }, { @@ -248,20 +237,13 @@ "current_graph.add(gtsam.PriorFactorPose2(current_pose_key, initial_pose, PRIOR_NOISE))\n", "\n", "# Variables to store results for animation\n", - "results_history = [gtsam.Values(current_estimate)] # Store Values object at each step\n", - "marginals_history = [] # Store Marginals object at each step\n", + "history = [] # Store SLAMFrameData objects for each step\n", "known_landmarks = set() # Set of landmark keys L(j) currently in the state\n", "\n", "# Initial marginals (only for X(0))\n", - "try:\n", - " initial_gaussian_graph = current_graph.linearize(current_estimate)\n", - " initial_marginals = gtsam.Marginals(initial_gaussian_graph, current_estimate)\n", - " marginals_history.append(initial_marginals)\n", - "except Exception as e:\n", - " print(f\"Error computing initial marginals: {e}\")\n", - " marginals_history.append(None) # Append placeholder if fails\n", - "\n", - "maybe_show(current_graph, current_estimate, \"--- Step 0 --- Initial Graph with Prior ---\")" + "initial_gaussian_graph = current_graph.linearize(current_estimate)\n", + "initial_marginals = gtsam.Marginals(initial_gaussian_graph, current_estimate)\n", + "history.append(SlamFrameData(0, current_estimate, initial_marginals, mage_dot(current_graph, current_estimate)))" ] }, { @@ -324,16 +306,12 @@ " optimizer = gtsam.LevenbergMarquardtOptimizer(current_graph, current_estimate)\n", " current_estimate = optimizer.optimize() # Update estimates by optimizing the whole graph so far\n", " \n", - " # --- Display Factor Graph --- \n", - " maybe_show(current_graph, current_estimate, f\"--- Step {k+1} --- Factor Graph ---\")\n", - "\n", - " # Store results for animation\n", - " results_history.append(gtsam.Values(current_estimate)) # Store a copy\n", - " \n", " # Calculate marginals for visualization (can be slow for large graphs)\n", " current_gaussian_graph = current_graph.linearize(current_estimate)\n", " current_marginals = gtsam.Marginals(current_gaussian_graph, current_estimate)\n", - " marginals_history.append(current_marginals)\n", + "\n", + " # Store the current state for visualization\n", + " history.append(SlamFrameData(k+1, current_estimate, current_marginals, mage_dot(current_graph, current_estimate)))\n", "\n", "print(\"\\nIterative Factor Graph SLAM finished.\")\n", "print(f\"Final number of poses estimated: {len([key for key in current_estimate.keys() if gtsam.Symbol(key).chr() == ord('x')])}\")\n", @@ -360,15 +338,14 @@ }, "outputs": [], "source": [ - "fig = gtsam_plotly.create_slam_animation(\n", - " results_history=results_history,\n", - " marginals_history=marginals_history,\n", + "fig = create_slam_animation(\n", + " history,\n", " landmarks_gt_array=landmarks_gt_array,\n", " poses_gt=poses_gt,\n", - " num_steps=NUM_STEPS,\n", " world_size=WORLD_SIZE,\n", - " X=X, # Pass symbol functions\n", - " L=L\n", + " X=X, # Pass symbol functions\n", + " L=L,\n", + " max_landmark_index=NUM_LANDMARKS,\n", ")\n", "\n", "print(\"Displaying animation...\")\n", @@ -384,7 +361,7 @@ "\n", "* **Approach:** This notebook implemented SLAM iteratively using GTSAM factor graphs. At each step, new factors (odometry, measurements) were added, and the graph was re-optimized using Levenberg-Marquardt. This is more akin to **incremental smoothing** than a classic filter (which would explicitly marginalize past states).\n", "* **Modularity:** Simulation and Plotly visualization code have been moved into separate `simulation.py` and `gtsam_plotly.py` files for better organization.\n", - "* **Graph Visualization:** The `graphviz` library was used to render the factor graph at each step (or only at the end, depending on `SHOW_GRAPHVIZ_EACH_STEP`). This helps visualize how the graph structure grows and connects poses and landmarks.\n", + "* **Graph Visualization:** The `graphviz` library was used to render the factor graph at each step. This helps visualize how the graph structure grows and connects poses and landmarks.\n", "* **Efficiency:** Optimizing the entire graph at every step becomes computationally expensive for long trajectories. For real-time performance or large problems, **iSAM2 (Incremental Smoothing and Mapping)** is the preferred GTSAM algorithm. iSAM2 efficiently updates the solution by only re-linearizing and re-solving parts of the graph affected by new measurements.\n", "* **Accuracy vs. EKF:** This factor graph approach generally handles non-linearities better than a standard EKF because it re-linearizes during optimization. It avoids some of the consistency pitfalls associated with the EKF's single linearization point per step.\n", "* **Visualization:** The Plotly animation shows the evolution of the robot's path estimate, the map of landmarks, and their associated uncertainties (covariance ellipses). You can observe how measurements help refine the estimates and reduce uncertainty, especially when loops are closed (implicitly here through repeated observations of landmarks)." diff --git a/python/gtsam/examples/gtsam_plotly.py b/python/gtsam/examples/gtsam_plotly.py index d37e32dc9..dbbae8f46 100644 --- a/python/gtsam/examples/gtsam_plotly.py +++ b/python/gtsam/examples/gtsam_plotly.py @@ -1,63 +1,78 @@ -# gtsam_plotly_modular_v2.py +# gtsam_plotly_modular_v3.py +import base64 +from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Tuple +import graphviz import numpy as np import plotly.graph_objects as go from tqdm.notebook import tqdm import gtsam -# --- Core Ellipse Calculations --- + +# --- Dataclass for History --- +@dataclass +class SlamFrameData: + """Holds all data needed for a single frame of the SLAM animation.""" + + step_index: int + results: gtsam.Values + marginals: Optional[gtsam.Marginals] + graph_dot_string: Optional[str] = None # Store the Graphviz DOT string -def ellipse_path( - cx: float, cy: float, sizex: float, sizey: float, angle: float, N: int = 60 +# --- Core Ellipse Calculation & Path Generation --- + + +def create_ellipse_path_from_cov( + cx: float, cy: float, cov_matrix: np.ndarray, scale: float = 2.0, N: int = 60 ) -> str: - """Generates SVG path string for a rotated ellipse.""" - angle_rad = np.radians(angle) - t = np.linspace(0, 2 * np.pi, N) - x_unit = (sizex / 2) * np.cos(t) - y_unit = (sizey / 2) * np.sin(t) - x_rot = cx + x_unit * np.cos(angle_rad) - y_unit * np.sin(angle_rad) - y_rot = cy + x_unit * np.sin(angle_rad) + y_unit * np.cos(angle_rad) + """Generates SVG path string for an ellipse directly from 2x2 covariance.""" + cov = cov_matrix[:2, :2] + np.eye(2) * 1e-9 # Ensure positive definite + try: + eigvals, eigvecs = np.linalg.eigh(cov) + # eigh sorts eigenvalues in ascending order + eigvals = np.maximum(eigvals, 1e-9) # Ensure positive eigenvalues + minor_std, major_std = np.sqrt(eigvals) + v_minor, v_major = eigvecs[:, 0], eigvecs[:, 1] # Corresponding eigenvectors + except np.linalg.LinAlgError: + # Fallback to a small circle if decomposition fails + radius = 0.1 * scale + t = np.linspace(0, 2 * np.pi, N) + x_p = cx + radius * np.cos(t) + y_p = cy + radius * np.sin(t) + else: + # Parametric equation: Center + scale * major_std * v_major * cos(t) + scale * minor_std * v_minor * sin(t) + # Note: We use scale*std which corresponds to half-axis lengths a,b used before + t = np.linspace(0, 2 * np.pi, N) + cos_t, sin_t = np.cos(t), np.sin(t) + x_p = cx + scale * ( + major_std * cos_t * v_major[0] + minor_std * sin_t * v_minor[0] + ) + y_p = cy + scale * ( + major_std * cos_t * v_major[1] + minor_std * sin_t * v_minor[1] + ) + + # Build SVG path string path = ( - f"M {x_rot[0]},{y_rot[0]} " - + " ".join(f"L{x_},{y_}" for x_, y_ in zip(x_rot[1:], y_rot[1:])) + f"M {x_p[0]},{y_p[0]} " + + " ".join(f"L{x_},{y_}" for x_, y_ in zip(x_p[1:], y_p[1:])) + " Z" ) return path -def gtsam_cov_to_plotly_ellipse( - cov_matrix: np.ndarray, scale: float = 2.0 -) -> Tuple[float, float, float]: - """Calculates ellipse angle (deg), width, height from 2x2 covariance.""" - cov = cov_matrix[:2, :2] + np.eye(2) * 1e-9 # Ensure positive definite - try: - eigvals, eigvecs = np.linalg.eigh(cov) - eigvals = np.maximum(eigvals, 1e-9) # Ensure positive eigenvalues - except np.linalg.LinAlgError: - return 0, 0.1 * scale, 0.1 * scale # Default on failure - - width = 2 * scale * np.sqrt(eigvals[1]) # Major axis (largest eigenvalue) - height = 2 * scale * np.sqrt(eigvals[0]) # Minor axis (smallest eigenvalue) - angle_rad = np.arctan2( - eigvecs[1, 1], eigvecs[0, 1] - ) # Angle of major axis eigenvector - angle_deg = np.degrees(angle_rad) - return angle_deg, width, height - - # --- Plotly Element Generators --- -def create_gt_landmarks_trace(landmarks_gt_array: np.ndarray) -> Optional[go.Scatter]: +def create_gt_landmarks_trace(landmarks_gt: np.ndarray) -> Optional[go.Scatter]: """Creates scatter trace for ground truth landmarks.""" - if landmarks_gt_array is None or landmarks_gt_array.size == 0: + if landmarks_gt is None or landmarks_gt.size == 0: return None return go.Scatter( - x=landmarks_gt_array[0, :], - y=landmarks_gt_array[1, :], + x=landmarks_gt[0, :], + y=landmarks_gt[1, :], mode="markers", marker=dict(color="black", size=8, symbol="star"), name="Landmarks GT", @@ -68,11 +83,9 @@ def create_gt_path_trace(poses_gt: List[gtsam.Pose2]) -> Optional[go.Scatter]: """Creates line trace for ground truth path.""" if not poses_gt: return None - gt_path_x = [p.x() for p in poses_gt] - gt_path_y = [p.y() for p in poses_gt] return go.Scatter( - x=gt_path_x, - y=gt_path_y, + x=[p.x() for p in poses_gt], + y=[p.y() for p in poses_gt], mode="lines", line=dict(color="gray", width=1, dash="dash"), name="Path GT", @@ -82,21 +95,21 @@ def create_gt_path_trace(poses_gt: List[gtsam.Pose2]) -> Optional[go.Scatter]: def create_est_path_trace( est_path_x: List[float], est_path_y: List[float] ) -> go.Scatter: - """Creates scatter/line trace for the estimated path up to current step.""" + """Creates trace for the estimated path.""" return go.Scatter( x=est_path_x, y=est_path_y, mode="lines+markers", line=dict(color="red", width=2), marker=dict(size=4, color="red"), - name="Path Est", # This name applies to the trace in the specific frame + name="Path Est", ) def create_est_landmarks_trace( est_landmarks_x: List[float], est_landmarks_y: List[float] ) -> Optional[go.Scatter]: - """Creates scatter trace for currently estimated landmarks.""" + """Creates trace for estimated landmarks.""" if not est_landmarks_x: return None return go.Scatter( @@ -104,136 +117,155 @@ def create_est_landmarks_trace( y=est_landmarks_y, mode="markers", marker=dict(color="blue", size=6, symbol="x"), - name="Landmarks Est", # Applies to landmarks in the specific frame + name="Landmarks Est", ) def _create_ellipse_shape_dict( - cx, cy, angle, width, height, fillcolor, line_color, name_suffix + cx, cy, cov, scale, fillcolor, line_color ) -> Dict[str, Any]: - """Helper to create the dictionary for a Plotly ellipse shape.""" + """Helper: Creates dict for a Plotly ellipse shape from covariance.""" + path = create_ellipse_path_from_cov(cx, cy, cov, scale) return dict( type="path", - path=ellipse_path(cx=cx, cy=cy, sizex=width, sizey=height, angle=angle), + path=path, xref="x", yref="y", fillcolor=fillcolor, line_color=line_color, - # name=f"{name_suffix} Cov", # Name isn't really used by Plotly for shapes ) def create_pose_ellipse_shape( - pose_mean_xy: np.ndarray, pose_cov: np.ndarray, k: int, scale: float + pose_mean_xy: np.ndarray, pose_cov: np.ndarray, scale: float ) -> Dict[str, Any]: - """Creates shape dictionary for a pose covariance ellipse.""" - angle, width, height = gtsam_cov_to_plotly_ellipse(pose_cov, scale) + """Creates shape dict for a pose covariance ellipse.""" return _create_ellipse_shape_dict( cx=pose_mean_xy[0], cy=pose_mean_xy[1], - angle=angle, - width=width, - height=height, + cov=pose_cov, + scale=scale, fillcolor="rgba(255,0,255,0.2)", line_color="rgba(255,0,255,0.5)", - name_suffix=f"Pose {k}", ) def create_landmark_ellipse_shape( - lm_mean_xy: np.ndarray, lm_cov: np.ndarray, lm_index: int, scale: float + lm_mean_xy: np.ndarray, lm_cov: np.ndarray, scale: float ) -> Dict[str, Any]: - """Creates shape dictionary for a landmark covariance ellipse.""" - angle, width, height = gtsam_cov_to_plotly_ellipse(lm_cov, scale) + """Creates shape dict for a landmark covariance ellipse.""" return _create_ellipse_shape_dict( cx=lm_mean_xy[0], cy=lm_mean_xy[1], - angle=angle, - width=width, - height=height, + cov=lm_cov, + scale=scale, fillcolor="rgba(0,0,255,0.1)", line_color="rgba(0,0,255,0.3)", - name_suffix=f"LM {lm_index}", ) +def dot_string_to_base64_png( + dot_string: Optional[str], engine: str = "neato" +) -> Optional[str]: + """Renders a DOT string to a base64 encoded PNG using graphviz.""" + if not dot_string: + return None + source = graphviz.Source(dot_string, engine="neato") + # Use pipe() to get bytes directly without saving a file + png_bytes = source.pipe(format="png") + encoded_string = base64.b64encode(png_bytes).decode() + return f"data:image/png;base64,{encoded_string}" + # --- Frame Content Generation --- - - def generate_frame_content( - k: int, - step_results: gtsam.Values, - step_marginals: Optional[gtsam.Marginals], + frame_data: SlamFrameData, X: Callable[[int], int], L: Callable[[int], int], - max_landmark_index: int, # Need to know the potential range of landmarks + max_landmark_index: int, ellipse_scale: float = 2.0, + graphviz_engine: str = "neato", # Added engine parameter verbose: bool = False, -) -> Tuple[List[go.Scatter], List[Dict[str, Any]]]: - """Generates all dynamic traces and shapes for a single animation frame `k`.""" +) -> Tuple[List[go.Scatter], List[Dict[str, Any]], Optional[Dict[str, Any]]]: + """Generates dynamic traces, shapes, and layout image for a single frame.""" + k = frame_data.step_index + step_results = frame_data.results + step_marginals = frame_data.marginals + frame_traces: List[go.Scatter] = [] frame_shapes: List[Dict[str, Any]] = [] + layout_image: Optional[Dict[str, Any]] = None - # 1. Gather Estimated Path Data - est_path_x = [] - est_path_y = [] - for i in range(k + 1): - pose_key = X(i) - if step_results.exists(pose_key): - pose = step_results.atPose2(pose_key) - est_path_x.append(pose.x()) - est_path_y.append(pose.y()) + # 1. Estimated Path (Unchanged) + est_path_x = [ + step_results.atPose2(X(i)).x() + for i in range(k + 1) + if step_results.exists(X(i)) + ] + est_path_y = [ + step_results.atPose2(X(i)).y() + for i in range(k + 1) + if step_results.exists(X(i)) + ] frame_traces.append(create_est_path_trace(est_path_x, est_path_y)) - # 2. Gather Estimated Landmark Data - est_landmarks_x = [] - est_landmarks_y = [] - landmark_keys_in_frame = [] - # Check all potential landmark keys up to max_landmark_index + # 2. Estimated Landmarks (Unchanged) + est_landmarks_x, est_landmarks_y, landmark_keys = [], [], [] for j in range(max_landmark_index + 1): lm_key = L(j) if step_results.exists(lm_key): lm_point = step_results.atPoint2(lm_key) est_landmarks_x.append(lm_point[0]) est_landmarks_y.append(lm_point[1]) - landmark_keys_in_frame.append(lm_key) - + landmark_keys.append(lm_key) lm_trace = create_est_landmarks_trace(est_landmarks_x, est_landmarks_y) if lm_trace: frame_traces.append(lm_trace) - # 3. Generate Covariance Ellipses (if marginals available) - if step_marginals is not None: - # Pose ellipse - current_pose_key = X(k) - if step_results.exists(current_pose_key): + # 3. Covariance Ellipses (Unchanged) + if step_marginals: + pose_key = X(k) + if step_results.exists(pose_key): try: - pose_cov = step_marginals.marginalCovariance(current_pose_key) - pose_mean = step_results.atPose2(current_pose_key).translation() - frame_shapes.append( - create_pose_ellipse_shape(pose_mean, pose_cov, k, ellipse_scale) - ) + cov = step_marginals.marginalCovariance(pose_key) + mean = step_results.atPose2(pose_key).translation() + frame_shapes.append(create_pose_ellipse_shape(mean, cov, ellipse_scale)) except Exception as e: if verbose: print(f"Warn: Pose {k} cov err @ step {k}: {e}") - - # Landmark ellipses - for lm_key in landmark_keys_in_frame: + for lm_key in landmark_keys: try: - lm_cov = step_marginals.marginalCovariance(lm_key) - lm_mean = step_results.atPoint2(lm_key) - lm_index = gtsam.Symbol(lm_key).index() + cov = step_marginals.marginalCovariance(lm_key) + mean = step_results.atPoint2(lm_key) frame_shapes.append( - create_landmark_ellipse_shape( - lm_mean, lm_cov, lm_index, ellipse_scale - ) + create_landmark_ellipse_shape(mean, cov, ellipse_scale) ) except Exception as e: - lm_index = gtsam.Symbol(lm_key).index() if verbose: - print(f"Warn: LM {lm_index} cov err @ step {k}: {e}") + print( + f"Warn: LM {gtsam.Symbol(lm_key).index()} cov err @ step {k}: {e}" + ) - return frame_traces, frame_shapes + # 4. Graph Image for Layout (MODIFIED) + # Use the new function with the dot string from frame_data + img_src = dot_string_to_base64_png( + frame_data.graph_dot_string, engine=graphviz_engine + ) + if img_src: + layout_image = dict( + source=img_src, + xref="paper", + yref="paper", + x=0, + y=1, + sizex=0.48, + sizey=1, + xanchor="left", + yanchor="top", + layer="below", + sizing="contain", + ) + + return frame_traces, frame_shapes, layout_image # --- Figure Configuration --- @@ -244,9 +276,12 @@ def configure_figure_layout( num_steps: int, world_size: float, initial_shapes: List[Dict[str, Any]], + initial_image: Optional[Dict[str, Any]], ) -> None: - """Configures Plotly figure layout, axes, slider, buttons.""" + """Configures Plotly figure layout for side-by-side display.""" steps = list(range(num_steps + 1)) + plot_domain = [0.52, 1.0] # Right pane for the SLAM plot + sliders = [ dict( active=0, @@ -275,10 +310,10 @@ def configure_figure_layout( showactive=False, direction="left", pad={"r": 10, "t": 87}, - x=0.1, - xanchor="right", + x=plot_domain[0], + xanchor="left", y=0, - yanchor="top", + yanchor="top", # Position relative to plot area buttons=[ dict( label="Play", @@ -310,28 +345,29 @@ def configure_figure_layout( ] fig.update_layout( - title="Iterative Factor Graph SLAM Animation", - xaxis=dict(range=[-world_size / 2 - 2, world_size / 2 + 2], constrain="domain"), + title="Factor Graph SLAM Animation (Graph Left, Results Right)", + xaxis=dict( + range=[-world_size / 2 - 2, world_size / 2 + 2], + domain=plot_domain, # Confine axis to right pane + constrain="domain", + ), yaxis=dict( range=[-world_size / 2 - 2, world_size / 2 + 2], scaleanchor="x", scaleratio=1, + domain=[0, 1], # Full height within its domain ), - width=800, - height=800, + width=1000, + height=600, # Adjust size for side-by-side hovermode="closest", updatemenus=updatemenus, sliders=sliders, - shapes=initial_shapes, - legend=dict( - traceorder="reversed", - title_text="Legend", - orientation="h", - yanchor="bottom", - y=1.02, - xanchor="right", - x=1, - ), + shapes=initial_shapes, # Initial shapes for the SLAM plot + images=( + [initial_image] if initial_image else [] + ), # Initial image for the left pane + showlegend=False, # Forego legend for space + # Autosize=True might help responsiveness but can interfere with domain ) @@ -339,23 +375,25 @@ def configure_figure_layout( def create_slam_animation( - results_history: List[gtsam.Values], - marginals_history: List[Optional[gtsam.Marginals]], - num_steps: int, + history: List[SlamFrameData], X: Callable[[int], int], L: Callable[[int], int], - max_landmark_index: int, # Required to iterate potential landmarks + max_landmark_index: int, landmarks_gt_array: Optional[np.ndarray] = None, poses_gt: Optional[List[gtsam.Pose2]] = None, world_size: float = 20.0, ellipse_scale: float = 2.0, + graphviz_engine: str = "neato", # Pass engine choice verbose_cov_errors: bool = False, ) -> go.Figure: - """Creates a modular Plotly SLAM animation.""" + """Creates a side-by-side Plotly SLAM animation using a history of dataclasses.""" + if not history: + raise ValueError("History cannot be empty.") print("Generating Plotly animation...") + num_steps = history[-1].step_index fig = go.Figure() - # 1. Add static ground truth traces to the base figure (visible always) + # 1. Add static GT traces (Unchanged) gt_lm_trace = create_gt_landmarks_trace(landmarks_gt_array) if gt_lm_trace: fig.add_trace(gt_lm_trace) @@ -363,51 +401,58 @@ def create_slam_animation( if gt_path_trace: fig.add_trace(gt_path_trace) - # 2. Generate frames with dynamic content + # 2. Generate frames (MODIFIED: pass graphviz_engine) frames = [] + initial_dynamic_traces, initial_shapes, initial_image = [], [], None steps_iterable = range(num_steps + 1) try: steps_iterable = tqdm(steps_iterable, desc="Creating Frames") except NameError: - pass # tqdm optional + pass for k in steps_iterable: - step_results = results_history[k] - step_marginals = marginals_history[k] if marginals_history else None + frame_data = next((item for item in history if item.step_index == k), None) + if frame_data is None: + print(f"Warning: Missing data for step {k} in history.") + continue - frame_traces, frame_shapes = generate_frame_content( - k, - step_results, - step_marginals, + # Pass the engine choice to the content generator + frame_traces, frame_shapes, layout_image = generate_frame_content( + frame_data, X, L, max_landmark_index, ellipse_scale, + graphviz_engine, verbose_cov_errors, ) + + if k == 0: + initial_dynamic_traces, initial_shapes, initial_image = ( + frame_traces, + frame_shapes, + layout_image, + ) + frames.append( go.Frame( - data=frame_traces, name=str(k), layout=go.Layout(shapes=frame_shapes) + data=frame_traces, + name=str(k), + layout=go.Layout( + shapes=frame_shapes, images=[layout_image] if layout_image else [] + ), ) ) - # 3. Set initial dynamic data (from frame 0) onto the base figure - initial_dynamic_traces = [] - initial_shapes = [] - if frames: - # Important: Add *copies* or ensure traces are regenerated if needed, - # though Plotly usually handles this ok with frame data. - initial_dynamic_traces = frames[0].data - initial_shapes = frames[0].layout.shapes if frames[0].layout else [] - for trace in initial_dynamic_traces: - fig.add_trace(trace) # Add Est Path[0], Est Landmarks[0] traces + # 3. Add initial dynamic traces (Unchanged) + for trace in initial_dynamic_traces: + fig.add_trace(trace) - # 4. Assign frames to the figure + # 4. Assign frames (Unchanged) fig.update(frames=frames) - # 5. Configure layout, axes, controls - # Pass initial_shapes for the layout's starting state - configure_figure_layout(fig, num_steps, world_size, initial_shapes) + # 5. Configure layout (Unchanged) + configure_figure_layout(fig, num_steps, world_size, initial_shapes, initial_image) print("Plotly animation generated.") return fig