# TensorRT-Demo/trt_modnet.py
#
# 171 lines
# 5.4 KiB
# Python

"""trt_modnet.py
This script demonstrates how to do real-time "image matting" with a
TensorRT-optimized MODNet engine.
"""
import argparse
import numpy as np
import cv2
import pycuda.autoinit # This is needed for initializing CUDA driver
from utils.camera import add_camera_args, Camera
from utils.writer import get_video_writer
from utils.background import Background
from utils.display import open_window, show_fps
from utils.display import FpsCalculator, ScreenToggler
from utils.modnet import TrtMODNet
# Name of the OpenCV window used by this demo; the same name is passed to
# open_window(), cv2.getWindowProperty() and cv2.imshow() below.
WINDOW_NAME = 'TrtMODNetDemo'
def parse_args():
    """Build the command-line parser and return the parsed arguments.

    Besides whatever options add_camera_args() registers on the parser,
    the following are defined here:
        --background:   background image or video file name
        --create_video: output video file name (.ts or .mp4)
        --demo_mode:    flag enabling the special "demo mode"
    """
    description = (
        'Capture and display live camera video, while doing '
        'real-time image matting with TensorRT optimized MODNet')
    parser = add_camera_args(
        argparse.ArgumentParser(description=description))

    # Both file-name options are plain strings defaulting to ''.
    file_options = (
        ('--background',
         'background image or video file name [None]'),
        ('--create_video',
         'create output video (either .ts or .mp4) [None]'),
    )
    for flag, help_text in file_options:
        parser.add_argument(flag, type=str, default='', help=help_text)

    parser.add_argument(
        '--demo_mode', action='store_true',
        help='run the program in a special "demo mode" [False]')
    return parser.parse_args()
class BackgroundBlender():
    """BackgroundBlender

    Alpha-blends a foreground image onto a background image using a matte.

    # Arguments
        demo_mode: if True, do foreground/background blending in a
                   special "demo mode" which alternates among the
                   original, replaced and black backgrounds.
    """

    def __init__(self, demo_mode=False):
        self.demo_mode = demo_mode
        # Frame counter driving the demo script; wraps around at 360.
        self.count = 0

    def blend(self, img, bg, matte):
        """Blend foreground and background using the 'matte'.

        The input arrays are never modified, in demo mode included.

        # Arguments
            img:   uint8 np.array of shape (H, W, 3), the foreground image
            bg:    uint8 np.array of shape (H, W, 3), the background image
            matte: float32 np.array of shape (H, W), values between 0.0 and 1.0

        # Returns
            uint8 np.array of shape (H, W, 3), computed as
            img * matte + bg * (1 - matte)
        """
        if self.demo_mode:
            img, bg, matte = self._mod_for_demo(img, bg, matte)
        alpha = matte[..., np.newaxis]  # (H, W, 1), broadcasts over channels
        return (img * alpha + bg * (1 - alpha)).astype(np.uint8)

    def _mod_for_demo(self, img, bg, matte):
        """Modify img, bg and matte for "demo mode"

        Returns an (img, bg, matte) tuple for the current frame and then
        advances the frame counter.  Arrays that need changing are copied
        first, so the caller's `bg` and `matte` buffers stay intact even
        if their producers hand out the same array on every frame.

        # Demo script (based on "count")
              0~ 59: black background left to right
             60~119: black background only
            120~179: replaced background left to right
            180~239: replaced background
            240~299: original background left to right
            300~359: original background
        """
        img_w = img.shape[1]
        if self.count < 120:
            # Black background; a fresh array, the caller's bg is untouched.
            bg = np.zeros(bg.shape, dtype=np.uint8)
            if self.count < 60:
                offset = int(img_w * self.count / 59)
                matte = matte.copy()
                # Right of the sweep line still shows the raw camera image.
                matte[:, offset:img_w] = 1.0
        elif self.count < 240:
            if self.count < 180:
                offset = int(img_w * (self.count - 120) / 59)
                bg = bg.copy()
                # Right of the sweep line is still black.
                bg[:, offset:img_w, :] = 0
        else:
            matte = matte.copy()
            if self.count < 300:
                offset = int(img_w * (self.count - 240) / 59)
                # Left of the sweep line shows the raw camera image.
                matte[:, 0:offset] = 1.0
            else:
                matte[:, :] = 1.0
        self.count = (self.count + 1) % 360
        return img, bg, matte
class TrtMODNetRunner():
    """TrtMODNetRunner

    Ties together the camera, background source, MODNet engine and
    blender, and drives the capture/infer/blend/display loop.

    # Arguments
        modnet:  TrtMODNet instance
        cam:     Camera object (for reading foreground images)
        bggen:   background generator (for reading background images)
        blender: BackgroundBlender object
        writer:  VideoWriter object (for saving output video)
    """

    def __init__(self, modnet, cam, bggen, blender, writer=None):
        self.modnet = modnet
        self.cam = cam
        self.bggen = bggen
        self.blender = blender
        self.writer = writer
        # The display window is created up front, sized to the camera image.
        open_window(
            WINDOW_NAME, 'TensorRT MODNet Demo',
            cam.img_width, cam.img_height)

    def run(self):
        """Get img and bg, infer matte, blend and show img, then repeat.

        The loop ends when the window property query returns a negative
        value (presumably the window was closed), when the camera returns
        no image, or when the user presses ESC.  Pressing 'F' or 'f'
        toggles fullscreen.
        """
        toggler = ScreenToggler()
        fps_meter = FpsCalculator()
        while True:
            if cv2.getWindowProperty(WINDOW_NAME, 0) < 0:
                break

            # Both sources are read before checking for end of camera input.
            frame = self.cam.read()
            backdrop = self.bggen.read()
            if frame is None:
                break

            alpha = self.modnet.infer(frame)
            composed = self.blender.blend(frame, backdrop, alpha)
            composed = show_fps(composed, fps_meter.update())
            if self.writer:
                self.writer.write(composed)
            cv2.imshow(WINDOW_NAME, composed)

            pressed = cv2.waitKey(1)
            if pressed == 27:  # ESC key: quit
                break
            if pressed in (ord('F'), ord('f')):  # Toggle fullscreen
                toggler.toggle()

    def __del__(self):
        # Close every OpenCV window when the runner is destroyed.
        cv2.destroyAllWindows()
def main():
    """Run the MODNet matting demo end to end.

    Opens the camera, optionally creates a video writer, builds the
    MODNet engine, background source and blender, then runs the display
    loop.  The video writer (if any) and the camera are released when
    the loop finishes or when any step after opening the camera raises.

    # Raises
        SystemExit: if the camera cannot be opened.
    """
    args = parse_args()
    cam = Camera(args)
    if not cam.isOpened():
        raise SystemExit('ERROR: failed to open camera!')

    writer = None
    try:
        if args.create_video:
            writer = get_video_writer(
                args.create_video, cam.img_width, cam.img_height)
        modnet = TrtMODNet()
        bggen = Background(args.background, cam.img_width, cam.img_height)
        blender = BackgroundBlender(args.demo_mode)
        runner = TrtMODNetRunner(modnet, cam, bggen, blender, writer)
        runner.run()
    finally:
        # Release on the error path as well, so that a failure in engine
        # creation or in the loop does not skip releasing the output
        # video or the camera.
        if writer:
            writer.release()
        cam.release()


if __name__ == '__main__':
    main()