171 lines
5.4 KiB
Python
171 lines
5.4 KiB
Python
"""trt_modnet.py

This script demonstrates how to do real-time "image matting" with a
TensorRT optimized MODNet engine.
"""
|
|
|
|
|
|
import argparse
|
|
|
|
import numpy as np
|
|
import cv2
|
|
import pycuda.autoinit # This is needed for initializing CUDA driver
|
|
|
|
from utils.camera import add_camera_args, Camera
|
|
from utils.writer import get_video_writer
|
|
from utils.background import Background
|
|
from utils.display import open_window, show_fps
|
|
from utils.display import FpsCalculator, ScreenToggler
|
|
from utils.modnet import TrtMODNet
|
|
|
|
|
|
# Name of the display window; the same identifier is passed to
# open_window(), cv2.getWindowProperty() and cv2.imshow() below.
WINDOW_NAME = 'TrtMODNetDemo'
|
|
|
|
|
|
def parse_args():
    """Build the command-line parser and return the parsed arguments.

    The camera/input options are registered by `add_camera_args`; the
    options specific to this demo (`--background`, `--create_video` and
    `--demo_mode`) are added here.
    """
    parser = argparse.ArgumentParser(
        description=('Capture and display live camera video, while doing '
                     'real-time image matting with TensorRT optimized MODNet'))
    parser = add_camera_args(parser)

    # (flag, keyword arguments) pairs for the demo-specific options.
    demo_options = (
        ('--background',
         dict(type=str, default='',
              help='background image or video file name [None]')),
        ('--create_video',
         dict(type=str, default='',
              help='create output video (either .ts or .mp4) [None]')),
        ('--demo_mode',
         dict(action='store_true',
              help='run the program in a special "demo mode" [False]')),
    )
    for flag, options in demo_options:
        parser.add_argument(flag, **options)

    return parser.parse_args()
|
|
|
|
|
|
class BackgroundBlender():
    """Alpha-blend a foreground image over a background image.

    # Arguments
        demo_mode: if True, do foreground/background blending in a
                   special "demo mode" which alternates among the
                   original, replaced and black backgrounds.
    """

    def __init__(self, demo_mode=False):
        self.demo_mode = demo_mode
        # Frame counter driving the demo script; wraps around at 360.
        self.count = 0

    def blend(self, img, bg, matte):
        """Blend foreground and background using the 'matte'.

        The input arrays are never modified; in "demo mode" any
        adjustment is applied to private copies.

        # Arguments
            img:   uint8 np.array of shape (H, W, 3), the foreground image
            bg:    uint8 np.array of shape (H, W, 3), the background image
            matte: float32 np.array of shape (H, W), values between 0.0 and 1.0

        # Returns
            uint8 np.array of shape (H, W, 3), the blended image
        """
        if self.demo_mode:
            img, bg, matte = self._mod_for_demo(img, bg, matte)
        alpha = matte[..., np.newaxis]  # (H, W, 1), broadcasts over channels
        return (img * alpha + bg * (1 - alpha)).astype(np.uint8)

    def _mod_for_demo(self, img, bg, matte):
        """Modify img, bg and matte for "demo mode"

        Returns an (img, bg, matte) tuple for the current frame and
        advances the frame counter.  `bg` and `matte` are copied before
        being edited, so the caller's arrays (e.g. a background image
        that is reused for every frame) are left intact.

        # Demo script (based on "count")
              0~ 59: black background left to right
             60~119: black background only
            120~179: replaced background left to right
            180~239: replaced background
            240~299: original background left to right
            300~359: original background
        """
        _, img_w, _ = img.shape
        if self.count < 120:
            # Black background; a fresh array, so the caller's bg is untouched.
            bg = np.zeros(bg.shape, dtype=np.uint8)
            if self.count < 60:
                offset = int(img_w * self.count / 59)
                matte = matte.copy()
                # Columns right of the wipe still show the original frame.
                matte[:, offset:img_w] = 1.0
        elif self.count < 240:
            if self.count < 180:
                offset = int(img_w * (self.count - 120) / 59)
                bg = bg.copy()
                # Columns right of the wipe still show the black background.
                bg[:, offset:img_w, :] = 0
        else:
            matte = matte.copy()
            if self.count < 300:
                offset = int(img_w * (self.count - 240) / 59)
                # Columns left of the wipe show the original frame.
                matte[:, 0:offset] = 1.0
            else:
                matte[:, :] = 1.0
        self.count = (self.count + 1) % 360
        return img, bg, matte
|
|
|
|
|
|
class TrtMODNetRunner():
    """Run the read -> infer -> blend -> display loop of the demo.

    # Arguments
        modnet: TrtMODNet instance
        cam: Camera object (for reading foreground images)
        bggen: background generator (for reading background images)
        blender: BackgroundBlender object
        writer: VideoWriter object (for saving output video)
    """

    def __init__(self, modnet, cam, bggen, blender, writer=None):
        self.modnet, self.cam, self.bggen = modnet, cam, bggen
        self.blender, self.writer = blender, writer
        # The display window is sized to the camera frame.
        open_window(
            WINDOW_NAME, 'TensorRT MODNet Demo',
            cam.img_width, cam.img_height)

    def run(self):
        """Get img and bg, infer matte, blend and show img, then repeat.

        The loop ends when the window property query returns a negative
        value, when the camera returns no frame, or when ESC is pressed.
        """
        toggler = ScreenToggler()
        fps_meter = FpsCalculator()
        fullscreen_keys = (ord('F'), ord('f'))
        esc_key = 27

        while True:
            # Presumably negative once the window no longer exists.
            if cv2.getWindowProperty(WINDOW_NAME, 0) < 0:
                break

            frame = self.cam.read()
            background = self.bggen.read()
            if frame is None:
                break

            alpha = self.modnet.infer(frame)
            output = self.blender.blend(frame, background, alpha)
            output = show_fps(output, fps_meter.update())

            if self.writer:
                self.writer.write(output)
            cv2.imshow(WINDOW_NAME, output)

            pressed = cv2.waitKey(1)
            if pressed == esc_key:  # ESC key: quit
                break
            if pressed in fullscreen_keys:  # Toggle fullscreen
                toggler.toggle()

    def __del__(self):
        cv2.destroyAllWindows()
|
|
|
|
|
|
def main():
    """Entry point: set up camera, writer and engine, then run the demo.

    Exits with an error message (SystemExit) if the camera cannot be
    opened.  Once the camera is open, it and the optional video writer
    are released even when setup or the processing loop raises.
    """
    args = parse_args()

    cam = Camera(args)
    if not cam.isOpened():
        raise SystemExit('ERROR: failed to open camera!')

    writer = None
    try:
        if args.create_video:
            writer = get_video_writer(
                args.create_video, cam.img_width, cam.img_height)

        modnet = TrtMODNet()
        bggen = Background(args.background, cam.img_width, cam.img_height)
        blender = BackgroundBlender(args.demo_mode)

        runner = TrtMODNetRunner(modnet, cam, bggen, blender, writer)
        runner.run()
    finally:
        # Release in the original order: the output video first, then
        # the camera.
        if writer:
            writer.release()
        cam.release()
|
|
|
|
|
|
# Run the demo only when this file is executed as a script.
if __name__ == '__main__':
    main()
|