diff --git a/python/gtsam/examples/SFMExample_bal.py b/python/gtsam/examples/SFMExample_bal.py index 65b9e1334..77c186bd3 100644 --- a/python/gtsam/examples/SFMExample_bal.py +++ b/python/gtsam/examples/SFMExample_bal.py @@ -16,36 +16,37 @@ import sys import gtsam from gtsam import (GeneralSFMFactorCal3Bundler, - PriorFactorPinholeCameraCal3Bundler, PriorFactorPoint3, - readBal) + PriorFactorPinholeCameraCal3Bundler, PriorFactorPoint3) from gtsam.symbol_shorthand import C, P -from gtsam.utils.plot import plot_3d_points, plot_trajectory +from gtsam.utils import plot +from matplotlib import pyplot as plt logging.basicConfig(stream=sys.stdout, level=logging.INFO) +DEFAULT_BAL_DATASET = "dubrovnik-3-7-pre" -def plot(scene_data: gtsam.SfmData, result: gtsam.Values): - """Plot the trajectory.""" + +def plot_scene(scene_data: gtsam.SfmData, result: gtsam.Values): + """Plot the SFM results.""" plot_vals = gtsam.Values() for cam_idx in range(scene_data.number_cameras()): plot_vals.insert(C(cam_idx), result.atPinholeCameraCal3Bundler(C(cam_idx)).pose()) - for t_idx in range(scene_data.number_tracks()): - plot_vals.insert(P(t_idx), result.atPoint3(P(t_idx))) + for j in range(scene_data.number_tracks()): + plot_vals.insert(P(j), result.atPoint3(P(j))) - plot_3d_points(0, plot_vals, linespec="g.") - plot_trajectory(0, plot_vals, show=True) + plot.plot_3d_points(0, plot_vals, linespec="g.") + plot.plot_trajectory(0, plot_vals, title="SFM results") + + plt.show() def run(args: argparse.Namespace): """ Run LM optimization with BAL input data and report resulting error """ - if args.input_file: - input_file = args.input_file - else: - input_file = gtsam.findExampleDataFile("dubrovnik-3-7-pre") + input_file = args.input_file # Load the SfM data from file - scene_data = readBal(input_file) + scene_data = gtsam.readBal(input_file) logging.info("read %d tracks on %d cameras\n", scene_data.number_tracks(), scene_data.number_cameras()) @@ -56,16 +57,14 @@ def run(args: argparse.Namespace): noise = gtsam.noiseModel.Isotropic.Sigma(2, 1.0) # one pixel in u and v # Add measurements to the factor graph - j = 0 - for t_idx in range(scene_data.number_tracks()): - track = scene_data.track(t_idx) # SfmTrack + for j in range(scene_data.number_tracks()): + track = scene_data.track(j) # SfmTrack # retrieve the SfmMeasurement objects for m_idx in range(track.number_measurements()): # i represents the camera index, and uv is the 2d measurement i, uv = track.measurement(m_idx) # note use of shorthand symbols C and P graph.add(GeneralSFMFactorCal3Bundler(uv, noise, C(i), P(j))) - j += 1 # Add a prior on pose x1. This indirectly specifies where the origin is. graph.push_back( @@ -89,9 +88,9 @@ def run(args: argparse.Namespace): i += 1 # add each SfmTrack - for t_idx in range(scene_data.number_tracks()): - track = scene_data.track(t_idx) - initial.insert(P(t_idx), track.point3()) + for j in range(scene_data.number_tracks()): + track = scene_data.track(j) + initial.insert(P(j), track.point3()) # Optimize the graph and print results try: @@ -107,7 +106,7 @@ def run(args: argparse.Namespace): logging.info("initial error: %f", graph.error(initial)) logging.info("final error: %f", graph.error(result)) - plot(scene_data, result) + plot_scene(scene_data, result) def main(): @@ -116,7 +115,7 @@ def main(): parser.add_argument('-i', '--input_file', type=str, - default="", + default=gtsam.findExampleDataFile(DEFAULT_BAL_DATASET), help="""Read SFM data from the specified BAL file. The data format is described here: https://grail.cs.washington.edu/projects/bal/. BAL files contain (nrPoses, nrPoints, nrObservations), followed by (i,j,u,v) tuples,