269 lines
10 KiB
Python
269 lines
10 KiB
Python
#!/usr/bin/env python
|
|
from __future__ import print_function
|
|
import os, sys, shutil
|
|
import argparse
|
|
import json, re
|
|
from subprocess import check_output
|
|
import datetime
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
|
def load_json(path):
    """Load and return the JSON document stored at ``path``.

    The file is opened in a ``with`` block so the handle is closed as soon as
    parsing finishes (or fails), instead of waiting for garbage collection.
    """
    with open(path, "r") as f:
        return json.load(f)
|
|
|
|
|
|
def save_json(obj, path):
    """Serialize ``obj`` as indented JSON into ``path``.

    The data is first written to ``<path>.bak``, flushed and fsync'ed, and only
    then renamed over ``path``. A crash during serialization therefore leaves
    the previous ``path`` contents in place.

    Raises:
        OSError: if the temporary file cannot be written or moved into place.
    """
    tmp_file = path + ".bak"
    with open(tmp_file, "w") as f:
        json.dump(obj, f, indent=2)
        f.flush()
        # Make sure the bytes are on disk before the rename publishes them.
        os.fsync(f.fileno())
    try:
        os.rename(tmp_file, path)
    except OSError:
        # os.rename refuses to overwrite an existing target on some platforms
        # (e.g. Windows); retry after removing the old file. If there is no
        # old file, the failure has another cause, so surface it unchanged.
        if not os.path.exists(path):
            raise
        os.remove(path)
        os.rename(tmp_file, path)
|
|
|
|
|
|
def parse_evaluation_result(input_str, i):
    """Parse the text report of one evaluation run into a result dict.

    ``input_str`` is scanned for ``<label>: <number>`` pairs. Each label is
    routed as follows:

    * a label containing "Time" is stored as ``res['time']``;
    * a label containing "Average" is stored as ``res['error']['average']``;
    * a label containing "deviation" is stored as ``res['error']['std']``;
    * any other label is stored verbatim as a key of ``res['error']``.

    All values are converted to float. ``i`` is the zero-based frame index and
    is stored one-based in ``res['frame_number']``.

    ``input_str`` may also be ``bytes`` (what ``subprocess.check_output``
    returns on Python 3). It is decoded as UTF-8 first, because the str regex
    below cannot be applied to bytes.
    """
    if isinstance(input_str, bytes):
        input_str = input_str.decode("utf-8", "replace")

    res = {'frame_number': i + 1, 'error': {}}
    # label: letters, digits, spaces, dots and square brackets;
    # value: unsigned decimal, with or without a fractional part.
    regex = r"([A-Za-z. \[\].0-9]+):[ ]*([0-9]*\.[0-9]+|[0-9]+)"
    for label, value in re.findall(regex, input_str):
        if "Time" in label:
            res['time'] = float(value)
        elif "Average" in label:
            res['error']['average'] = float(value)
        elif "deviation" in label:
            res['error']['std'] = float(value)
        else:
            res['error'][label] = float(value)
    return res
|
|
|
|
|
|
def evaluate_sequence(sequence, algorithm, dataset, executable, img_files, gt_files,
                      state, state_path):
    """Run the evaluation executable on every consecutive image pair of a sequence.

    For each index ``i``, the executable is invoked as::

        executable img_files[i] img_files[i + 1] algorithm gt_files[i]

    Its output is parsed with ``parse_evaluation_result``. The list of
    per-frame results is stored in
    ``state[dataset][algorithm][-1]["eval_results"][sequence]``, and ``state``
    is then saved to ``state_path``.

    If that sequence is already present in ``eval_results`` (e.g. from an
    interrupted earlier run), nothing is recomputed.

    Returns:
        The list of per-frame result dicts for ``sequence`` (the stored one if
        the sequence had already been evaluated).

    Raises:
        subprocess.CalledProcessError: if the executable exits non-zero.
    """
    eval_results = state[dataset][algorithm][-1].setdefault("eval_results", {})
    if sequence in eval_results:
        return eval_results[sequence]

    num_pairs = len(img_files) - 1
    res = []
    for i in range(num_pairs):
        sys.stdout.write("Algorithm: %-20s Sequence: %-10s Done: [%3d/%3d]\r" %
                         (algorithm, sequence, i, num_pairs))
        sys.stdout.flush()
        # universal_newlines=True returns text on both Python 2 and 3; on
        # Python 3 raw bytes would break the str regex used for parsing.
        res_string = check_output([executable, img_files[i], img_files[i + 1],
                                   algorithm, gt_files[i]],
                                  universal_newlines=True)
        res.append(parse_evaluation_result(res_string, i))

    eval_results[sequence] = res
    # Persist after every sequence so an interrupted run can resume here.
    save_json(state, state_path)
    return res
|
|
|
|
#############################DATASET DEFINITIONS###############################
|
|
|
|
def evaluate_mpi_sintel(source_dir, algorithm, evaluation_executable, state, state_path):
    """Evaluate ``algorithm`` on every sequence of the MPI Sintel training set.

    Expected layout under ``source_dir``:

    * images: ``mpi_sintel/training/final/<seq>/*.png``
    * ground truth: ``mpi_sintel/training/flow/<seq>/*.flo``

    Sequences are taken from the image directory and processed in sorted order,
    so runs (and resumed runs) visit them deterministically. Per-sequence
    results are written into ``state`` by ``evaluate_sequence``.

    Returns:
        A dict mapping each sequence name to the value ``evaluate_sequence``
        returned for it.
    """
    img_dir = os.path.join(source_dir, 'mpi_sintel', 'training', 'final')
    gt_dir = os.path.join(source_dir, 'mpi_sintel', 'training', 'flow')

    def list_files(directory, extension):
        # Full paths of files in `directory` ending with `extension`, sorted by
        # name (assumes frame numbers in file names are zero-padded).
        return sorted(os.path.join(directory, f) for f in os.listdir(directory)
                      if f.endswith(extension))

    sequences = sorted(f for f in os.listdir(img_dir)
                       if os.path.isdir(os.path.join(img_dir, f)))

    evaluation_result = {}
    for seq in sequences:
        img_files = list_files(os.path.join(img_dir, seq), ".png")
        gt_files = list_files(os.path.join(gt_dir, seq), ".flo")
        evaluation_result[seq] = evaluate_sequence(seq, algorithm, 'mpi_sintel',
                                                   evaluation_executable, img_files,
                                                   gt_files, state, state_path)
    return evaluation_result
|
|
|
|
|
|
def evaluate_middlebury(source_dir, algorithm, evaluation_executable, state, state_path):
    """Evaluate ``algorithm`` on every Middlebury sequence that has ground truth.

    Expected layout under ``source_dir``:

    * images: ``middlebury/other-data/<seq>/*.png``
    * ground truth: ``middlebury/other-gt-flow/<seq>/*.flo``

    Sequences are taken from the ground-truth directory, so image-only
    sequences are skipped. They are processed in sorted order for
    deterministic, resumable runs. Per-sequence results are written into
    ``state`` by ``evaluate_sequence``.

    Returns:
        A dict mapping each sequence name to the value ``evaluate_sequence``
        returned for it.
    """
    img_dir = os.path.join(source_dir, 'middlebury', 'other-data')
    gt_dir = os.path.join(source_dir, 'middlebury', 'other-gt-flow')

    def list_files(directory, extension):
        # Full paths of files in `directory` ending with `extension`, sorted by
        # name (assumes frame numbers in file names are zero-padded).
        return sorted(os.path.join(directory, f) for f in os.listdir(directory)
                      if f.endswith(extension))

    sequences = sorted(f for f in os.listdir(gt_dir)
                       if os.path.isdir(os.path.join(gt_dir, f)))

    evaluation_result = {}
    for seq in sequences:
        img_files = list_files(os.path.join(img_dir, seq), ".png")
        gt_files = list_files(os.path.join(gt_dir, seq), ".flo")
        evaluation_result[seq] = evaluate_sequence(seq, algorithm, 'middlebury',
                                                   evaluation_executable, img_files,
                                                   gt_files, state, state_path)
    return evaluation_result
|
|
|
|
|
|
# Dispatch table: dataset name (as accepted by -d/--datasets) -> function that
# evaluates a single algorithm on that dataset. Every entry is called as
# f(source_dir, algorithm, evaluation_executable, state, state_path).
dataset_eval_functions = dict(
    mpi_sintel=evaluate_mpi_sintel,
    middlebury=evaluate_middlebury,
)
|
|
|
|
###############################################################################
|
|
|
|
def create_dir(dir):
    """Create directory ``dir`` (including parents) if it does not exist yet.

    Calling it on an existing directory is a no-op. The original
    check-then-create had a race: another process could create the directory
    in between and make ``makedirs`` fail. Here that case is handled by
    catching the error afterwards.

    Raises:
        OSError: if the directory cannot be created, or ``dir`` exists but is
        not a directory.
    """
    try:
        os.makedirs(dir)
    except OSError:
        if not os.path.isdir(dir):
            raise
|
|
|
|
|
|
def parse_sequence(input_str):
    """Split a comma-separated CLI value into a list of non-empty names.

    Surrounding whitespace is stripped from each item, and items that are
    empty after stripping are dropped. An empty string (or None) gives ``[]``.

    >>> parse_sequence("farneback, tvl1 ,")
    ['farneback', 'tvl1']
    """
    if not input_str:
        return []
    # Filter after stripping, so whitespace-only entries like "a, ,b" do not
    # produce an empty name.
    names = (item.strip() for item in input_str.split(","))
    return [name for name in names if name]
|
|
|
|
|
|
def build_chart(dst_folder, state, dataset):
    """Plot mean runtime vs. mean endpoint error for every stored evaluation.

    One point is drawn per evaluation instance in ``state[dataset]`` (every
    algorithm, every timestamped run), labelled ``"<algorithm>--<timestamp>"``.
    The point's coordinates are:

    * x (log scale): the mean of the per-frame ``time`` values;
    * y: the mean of the per-frame ``error['average']`` values.

    Both means are taken over all sequences of that instance. The chart is
    saved to ``<dst_folder>/evaluation_results_<dataset>.png``.

    Instances with no per-frame results (e.g. a run interrupted before its
    first sequence finished) are skipped instead of raising KeyError or
    ZeroDivisionError.
    """
    import itertools

    markers = ["o", "s", "h", "^", "D"]
    colors = ["b", "g", "r"]
    # Colour changes fastest, then marker: bo, go, ro, bs, gs, rs, ...
    # The sequence wraps once every marker/colour combination has been used.
    styles = itertools.cycle([c + m for m in markers for c in colors])

    fig = plt.figure(figsize=(16, 10))
    try:
        ax = fig.gca()
        for algo in state[dataset]:
            for eval_instance in state[dataset][algo]:
                frames = [frame
                          for seq_frames in eval_instance.get("eval_results", {}).values()
                          for frame in seq_frames]
                if not frames:
                    continue
                average_time = sum(f["time"] for f in frames) / len(frames)
                average_error = sum(f["error"]["average"] for f in frames) / len(frames)
                ax.plot([average_time], [average_error],
                        next(styles),
                        markersize=14,
                        label=algo + "--" + eval_instance["timestamp"])

        ax.set_ylabel('Average Endpoint Error (EPE)', fontsize=20)
        ax.set_xlabel('Average Runtime (seconds per frame)', fontsize=20)
        ax.set_xscale("log")
        ax.set_title('Evaluation on ' + dataset, fontsize=20)
        # Only add a legend when something was plotted (avoids a warning).
        if ax.get_legend_handles_labels()[0]:
            ax.legend()
        fig.savefig(os.path.join(dst_folder, "evaluation_results_" + dataset + ".png"),
                    bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close(fig)
|
|
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        description='Optical flow benchmarking script',
        formatter_class=argparse.RawDescriptionHelpFormatter)
    # nargs="?" makes the positional optional so the default is really used;
    # without it argparse treats bin_path as required and ignores the default.
    parser.add_argument(
        "bin_path",
        nargs="?",
        default="./optflow-example-optical_flow_evaluation",
        help="Path to the optical flow evaluation executable")
    parser.add_argument(
        "-a",
        "--algorithms",
        metavar="ALGORITHMS",
        default="",
        help=("Comma-separated list of optical-flow algorithms to evaluate "
              "(example: -a farneback,tvl1,deepflow). Note that previously "
              "evaluated algorithms are also included in the output charts"))
    parser.add_argument(
        "-d",
        "--datasets",
        metavar="DATASETS",
        default="mpi_sintel",
        help=("Comma-separated list of datasets for evaluation (currently only "
              "'mpi_sintel' and 'middlebury' are supported)"))
    parser.add_argument(
        "-f",
        "--dataset_folder",
        metavar="DATASET_FOLDER",
        default="./OF_datasets",
        help=("Path to a folder containing datasets. To enable evaluation on "
              "MPI Sintel dataset, please download it using the following links: "
              "http://files.is.tue.mpg.de/sintel/MPI-Sintel-training_images.zip and "
              "http://files.is.tue.mpg.de/sintel/MPI-Sintel-training_extras.zip and "
              "unzip these archives into the 'mpi_sintel' folder. To enable evaluation "
              "on the Middlebury dataset use the following links: "
              "http://vision.middlebury.edu/flow/data/comp/zip/other-color-twoframes.zip, "
              "http://vision.middlebury.edu/flow/data/comp/zip/other-gt-flow.zip. "
              "These should be unzipped into 'middlebury' folder"))
    parser.add_argument(
        "-o",
        "--out",
        metavar="OUT_DIR",
        default="./OF_evaluation_results",
        help="Output directory where to store benchmark results")
    parser.add_argument(
        "-s",
        "--state",
        metavar="STATE_JSON",
        default="./OF_evaluation_state.json",
        help=("Path to a json file that stores the current evaluation state and "
              "previous evaluation results"))
    args, other_args = parser.parse_known_args()
    if other_args:
        # parse_known_args silently drops unknown options. Without this
        # warning, a typo such as "--algorithm" would quietly run with defaults.
        print("Warning: ignoring unrecognized arguments: " + " ".join(other_args))

    if not os.path.isfile(args.bin_path):
        print("Error: " + args.bin_path + " does not exist")
        sys.exit(1)

    if not os.path.exists(args.dataset_folder):
        print("Error: " + args.dataset_folder + (" does not exist. Please, correctly "
                                                 "specify the -f parameter"))
        sys.exit(1)

    # Previously saved state lets interrupted evaluations resume and keeps
    # older results in the charts.
    state = {}
    if os.path.isfile(args.state):
        state = load_json(args.state)

    algorithm_list = parse_sequence(args.algorithms)
    dataset_list = parse_sequence(args.datasets)
    for dataset in dataset_list:
        if dataset not in dataset_eval_functions:
            print("Error: unsupported dataset " + dataset)
            sys.exit(1)
        if not os.path.isdir(os.path.join(args.dataset_folder, dataset)):
            print("Error: " + os.path.join(args.dataset_folder, dataset) + (" does not exist. "
                  "Please, download the dataset and follow the naming conventions "
                  "(use -h for more information)"))
            sys.exit(1)

    for dataset in dataset_list:
        dataset_state = state.setdefault(dataset, {})
        for algorithm in algorithm_list:
            # Each algorithm keeps a list of timestamped evaluation instances;
            # the last one is the one being (or last) evaluated.
            instances = dataset_state.setdefault(algorithm, [])
            if instances and "finished" not in instances[-1]:
                print(("Continuing an unfinished evaluation of " +
                       algorithm + " started at " + instances[-1]["timestamp"]))
            else:
                instances.append({"timestamp":
                                  datetime.datetime.now().strftime("%Y-%m-%d--%H-%M")})
            save_json(state, args.state)
            dataset_eval_functions[dataset](args.dataset_folder, algorithm, args.bin_path,
                                            state, args.state)
            instances[-1]["finished"] = True
            save_json(state, args.state)
    save_json(state, args.state)

    create_dir(args.out)
    for dataset in dataset_list:
        build_chart(args.out, state, dataset)
|