format and refactor the SFM BAL example

release/4.3a0
Varun Agrawal 2021-11-09 18:19:47 -05:00
parent a634a91c1a
commit 1bcb44784a
1 changed files with 59 additions and 48 deletions

View File

@ -7,49 +7,58 @@
See LICENSE for the license information See LICENSE for the license information
Solve a structure-from-motion problem from a "Bundle Adjustment in the Large" file Solve a structure-from-motion problem from a "Bundle Adjustment in the Large" file
Author: Frank Dellaert (Python: Akshay Krishnan, John Lambert) Author: Frank Dellaert (Python: Akshay Krishnan, John Lambert, Varun Agrawal)
""" """
import argparse import argparse
import logging import logging
import sys import sys
import matplotlib.pyplot as plt
import numpy as np
import gtsam import gtsam
from gtsam import ( from gtsam import (GeneralSFMFactorCal3Bundler,
GeneralSFMFactorCal3Bundler, PriorFactorPinholeCameraCal3Bundler, PriorFactorPoint3,
PinholeCameraCal3Bundler, readBal)
PriorFactorPinholeCameraCal3Bundler, from gtsam.symbol_shorthand import C, P
readBal, from gtsam.utils.plot import plot_3d_points, plot_trajectory
symbol_shorthand
)
C = symbol_shorthand.C logging.basicConfig(stream=sys.stdout, level=logging.INFO)
P = symbol_shorthand.P
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG) def plot(scene_data: gtsam.SfmData, result: gtsam.Values):
"""Plot the trajectory."""
plot_vals = gtsam.Values()
for cam_idx in range(scene_data.number_cameras()):
plot_vals.insert(C(cam_idx),
result.atPinholeCameraCal3Bundler(C(cam_idx)).pose())
for t_idx in range(scene_data.number_tracks()):
plot_vals.insert(P(t_idx), result.atPoint3(P(t_idx)))
def run(args): plot_3d_points(0, plot_vals, linespec="g.")
plot_trajectory(0, plot_vals, show=True)
def run(args: argparse.Namespace):
""" Run LM optimization with BAL input data and report resulting error """ """ Run LM optimization with BAL input data and report resulting error """
input_file = gtsam.findExampleDataFile(args.input_file) if args.input_file:
input_file = args.input_file
else:
input_file = gtsam.findExampleDataFile("dubrovnik-3-7-pre")
# Load the SfM data from file # Load the SfM data from file
scene_data = readBal(input_file) scene_data = readBal(input_file)
logging.info(f"read {scene_data.number_tracks()} tracks on {scene_data.number_cameras()} cameras\n") logging.info("read %d tracks on %d cameras\n", scene_data.number_tracks(),
scene_data.number_cameras())
# Create a factor graph # Create a factor graph
graph = gtsam.NonlinearFactorGraph() graph = gtsam.NonlinearFactorGraph()
# We share *one* noiseModel between all projection factors # We share *one* noiseModel between all projection factors
noise = gtsam.noiseModel.Isotropic.Sigma(2, 1.0) # one pixel in u and v noise = gtsam.noiseModel.Isotropic.Sigma(2, 1.0) # one pixel in u and v
# Add measurements to the factor graph # Add measurements to the factor graph
j = 0 j = 0
for t_idx in range(scene_data.number_tracks()): for t_idx in range(scene_data.number_tracks()):
track = scene_data.track(t_idx) # SfmTrack track = scene_data.track(t_idx) # SfmTrack
# retrieve the SfmMeasurement objects # retrieve the SfmMeasurement objects
for m_idx in range(track.number_measurements()): for m_idx in range(track.number_measurements()):
# i represents the camera index, and uv is the 2d measurement # i represents the camera index, and uv is the 2d measurement
@ -60,20 +69,18 @@ def run(args):
# Add a prior on pose x1. This indirectly specifies where the origin is. # Add a prior on pose x1. This indirectly specifies where the origin is.
graph.push_back( graph.push_back(
gtsam.PriorFactorPinholeCameraCal3Bundler( PriorFactorPinholeCameraCal3Bundler(
C(0), scene_data.camera(0), gtsam.noiseModel.Isotropic.Sigma(9, 0.1) C(0), scene_data.camera(0),
) gtsam.noiseModel.Isotropic.Sigma(9, 0.1)))
)
# Also add a prior on the position of the first landmark to fix the scale # Also add a prior on the position of the first landmark to fix the scale
graph.push_back( graph.push_back(
gtsam.PriorFactorPoint3( PriorFactorPoint3(P(0),
P(0), scene_data.track(0).point3(), gtsam.noiseModel.Isotropic.Sigma(3, 0.1) scene_data.track(0).point3(),
) gtsam.noiseModel.Isotropic.Sigma(3, 0.1)))
)
# Create initial estimate # Create initial estimate
initial = gtsam.Values() initial = gtsam.Values()
i = 0 i = 0
# add each PinholeCameraCal3Bundler # add each PinholeCameraCal3Bundler
for cam_idx in range(scene_data.number_cameras()): for cam_idx in range(scene_data.number_cameras()):
@ -81,12 +88,10 @@ def run(args):
initial.insert(C(i), camera) initial.insert(C(i), camera)
i += 1 i += 1
j = 0
# add each SfmTrack # add each SfmTrack
for t_idx in range(scene_data.number_tracks()): for t_idx in range(scene_data.number_tracks()):
track = scene_data.track(t_idx) track = scene_data.track(t_idx)
initial.insert(P(j), track.point3()) initial.insert(P(t_idx), track.point3())
j += 1
# Optimize the graph and print results # Optimize the graph and print results
try: try:
@ -94,25 +99,31 @@ def run(args):
params.setVerbosityLM("ERROR") params.setVerbosityLM("ERROR")
lm = gtsam.LevenbergMarquardtOptimizer(graph, initial, params) lm = gtsam.LevenbergMarquardtOptimizer(graph, initial, params)
result = lm.optimize() result = lm.optimize()
except Exception as e: except RuntimeError:
logging.exception("LM Optimization failed") logging.exception("LM Optimization failed")
return return
# Error drops from ~2764.22 to ~0.046 # Error drops from ~2764.22 to ~0.046
logging.info(f"final error: {graph.error(result)}") logging.info("initial error: %f", graph.error(initial))
logging.info("final error: %f", graph.error(result))
plot(scene_data, result)
def main():
"""Main runner."""
parser = argparse.ArgumentParser()
parser.add_argument('-i',
'--input_file',
type=str,
default="",
help="""Read SFM data from the specified BAL file.
The data format is described here: https://grail.cs.washington.edu/projects/bal/.
BAL files contain (nrPoses, nrPoints, nrObservations), followed by (i,j,u,v) tuples,
then (wx,wy,wz,tx,ty,tz,f,k1,k1) as Bundler camera calibrations w/ Rodrigues vector
and (x,y,z) 3d point initializations.""")
run(parser.parse_args())
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() main()
parser.add_argument(
'-i',
'--input_file',
type=str,
default="dubrovnik-3-7-pre",
help='Read SFM data from the specified BAL file'
'The data format is described here: https://grail.cs.washington.edu/projects/bal/.'
'BAL files contain (nrPoses, nrPoints, nrObservations), followed by (i,j,u,v) tuples, '
'then (wx,wy,wz,tx,ty,tz,f,k1,k1) as Bundler camera calibrations w/ Rodrigues vector'
'and (x,y,z) 3d point initializations.'
)
run(parser.parse_args())