本案例演示在 OpenVINO 中使用 MidasNet 进行单目深度估计,输入视频情况。模型信息可以在 这里找到。
环境描述:
本案例运行环境:Win10
IDE:VSCode
openvino版本:2022.1
代码链接,3-monodepth-imaging
文章目录
- openvino系列 7. 单目深度估算,输入为视频
- 单目深度估算的基本概念
- MidasNet的基本介绍
- 单目深度估算在视频中的应用
单目深度估算的基本概念
深度估计就是从RGB图像中估计图像中物体的深度,是一个从二维到三维的艰难过程。说道测距,我们首先会想到使用双目摄像头或者激光雷达,当然,这些方法各有优缺点,比如比如体积大(TOF)、能耗高(Kinect配有散热系统)、受环境影响(阳光中红外线影响)、算法复杂度高、实时性差(TOF实时性最高但精度较低)等。对于单目深度估算,其先天缺陷就是无法通过传感器直接得到精确的距离信息,但是随着软件算法的发展,我们可以通过深度学习来弥补硬件上的不足,同时为其他图像应用如语义分割、物体识别等提供更多的特征信息。
我们知道,就算我们闭上一只眼,也可以对眼前物体的距离有一个判断。 那也就是说,我们可以通过深度学习,希望机器能拥有像人脑一样的学习能力,2D图像的距离信息有一个估算。
MidasNet的基本介绍
在这个演示中,我们使用了一个名为MiDaS 的神经网络模型。论文出处:
R. Ranftl, K. Lasinger, D. Hafner, K. Schindler and V. Koltun, “Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer,” in IEEE Transactions on Pattern Analysis and Machine Intelligence, doi: 10.1109/TPAMI.2020.3019967.
这篇文章提出了一种监督的深度估计方法,具体来讲文章的策略可以归纳为: 1)使用多个深度数据集(各自拥有不同的scale和shift属性)加入进行训练,增大数据量与实现场景的互补; 2)提出了一种深度和偏移不变性的损失函数用于去监督深度的回归过程,从而使得可以更加有效使用现有数据; 3)采用从3D电影中进行采样的方式扩充数据集,从而进一步增加数据量; 4)使用带有原则属性的多目标训练方法,从而得到一种更加行之有效的优化方法; 结合上述的优化策略与方法,文章的最后得到的模型具有较强的泛化能力,从而摆脱了之前一些公开数据集场景依赖严重的问题。
单目深度估算在视频中的应用
MidasNet在视频中的应用实际上和图像类似,唯一的区别在于,视频是以多张图像的形式展现。这个案例中,我们将读取一个视频,并对于视频中的图像进行深度估算,返回深度图,并组成后处理视频,保存在本地。
代码整体逻辑:
- 首先,我们需要读取模型(ie.read_model)并且编译(ie.compile_model);
- 第二步,我们读取视频,并且我们设置的一些参数,比如FPS,视频图片缩放比,对我们即将输入模型的图像进行预处理;
- 第三步,对于视频中的每一帧输入图像,reshape其大小以符合模型的输入(reshape为 (N,C,H,W)(N=图像数,C=通道数,H=高度,W=宽度));
- 第四步,模型推理(compiled_model([input_image])[output_key])。得到的结果的尺寸和模型的输出尺寸相符。然后,我们将输出的结果转化为RGB图(通过函数convert_result_to_image),将其尺寸转换回输入是的图像大小,最后可视化结果。
- 第五步,我们将模型推理完的图片保存到本地(out_video.write(stacked_frame))。图片的集合就变成了视频。
代码如下:
import sys
import time
from pathlib import Path
import cv2
import matplotlib.cm
import matplotlib.pyplot as plt
import numpy as np
from IPython.display import (
HTML,
FileLink,
Pretty,
ProgressBar,
Video,
clear_output,
display,
)
from openvino.runtime import Core
DEVICE = "CPU"
MODEL_FILE = "model/MiDaS_small.xml"
model_xml_path = Path(MODEL_FILE)
def normalize_minmax(data):
"""
Normalizes the values in `data` between 0 and 1
"""
return (data - data.min()) / (data.max() - data.min())
def convert_result_to_image(result, colormap="viridis"):
"""
Convert network result of floating point numbers to an RGB image with
integer values from 0-255 by applying a colormap.
`result` is expected to be a single network result in 1,H,W shape
`colormap` is a matplotlib colormap.
See https://matplotlib.org/stable/tutorials/colors/colormaps.html
"""
cmap = matplotlib.cm.get_cmap(colormap)
result = result.squeeze(0)
result = normalize_minmax(result)
result = cmap(result)[:, :, :3] * 255
result = result.astype(np.uint8)
return result
def to_rgb(image_data) -> np.ndarray:
"""
Convert image_data from BGR to RGB
"""
return cv2.cvtColor(image_data, cv2.COLOR_BGR2RGB)
print("1 - Load Model")
ie = Core()
model = ie.read_model(model=model_xml_path, weights=model_xml_path.with_suffix(".bin"))
compiled_model = ie.compile_model(model=model, device_name=DEVICE)
input_key = compiled_model.input(0)
output_key = compiled_model.output(0)
print("- Input layer info: {}".format(input_key))
print("- Output layer info: {}".format(output_key))
network_input_shape = list(input_key.shape)
network_image_height, network_image_width = network_input_shape[2:]
print("- Setup video parameters.")
# Video source: https://www.youtube.com/watch?v=fu1xcQdJRws (Public Domain)
VIDEO_FILE = "data/cat-dog.mp4"
# Number of seconds of input video to process. Set to 0 to process the full video.
NUM_SECONDS = 0
# Set ADVANCE_FRAMES to 1 to process every frame from the input video
# Set ADVANCE_FRAMES to 2 to process every second frame. This reduces the time it takes to process the video
ADVANCE_FRAMES = 2
# Set SCALE_OUTPUT to reduce the size of the result video
SCALE_OUTPUT = 0.5
# The format to use for video encoding. vp09 is slow, but it works on most systems.
# Try the THEO encoding if you have FFMPEG installed.
# FOURCC = cv2.VideoWriter_fourcc(*"THEO")
FOURCC = cv2.VideoWriter_fourcc(*"vp09")
# Create Path objects for the input video and the resulting video
output_directory = Path("output")
output_directory.mkdir(exist_ok=True)
result_video_path = output_directory / f"{Path(VIDEO_FILE).stem}_monodepth.mp4"
print("2 - load video file.")
cap = cv2.VideoCapture(str(VIDEO_FILE))
ret, image = cap.read()
if not ret:
raise ValueError(f"The video at {VIDEO_FILE} cannot be read.")
input_fps = cap.get(cv2.CAP_PROP_FPS)
input_video_frame_height, input_video_frame_width = image.shape[:2]
target_fps = input_fps / ADVANCE_FRAMES
target_frame_height = int(input_video_frame_height * SCALE_OUTPUT)
target_frame_width = int(input_video_frame_width * SCALE_OUTPUT)
cap.release()
print("- The input video has a frame width of {}, frame height of {} and runs at {} fps."\
.format(input_video_frame_width, input_video_frame_height, int(input_fps)))
print("- The monodepth video will be scaled with a factor {}, have width {}, height {} and runs at {} fps."\
.format(SCALE_OUTPUT, target_frame_width, target_frame_height, int(target_fps)))
# Initialize variables
input_video_frame_nr = 0
start_time = time.perf_counter()
total_inference_duration = 0
# Open input video
cap = cv2.VideoCapture(str(VIDEO_FILE))
# Create result video
out_video = cv2.VideoWriter(
str(result_video_path),
FOURCC,
target_fps,
(target_frame_width * 2, target_frame_height),
)
num_frames = int(NUM_SECONDS * input_fps)
total_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT) if num_frames == 0 else num_frames
progress_bar = ProgressBar(total=total_frames)
progress_bar.display()
try:
while cap.isOpened():
ret, image = cap.read()
if not ret:
cap.release()
break
if input_video_frame_nr >= total_frames:
break
# Only process every second frame
# Prepare frame for inference
# resize to input shape for network
resized_image = cv2.resize(src=image, dsize=(network_image_height, network_image_width))
# reshape image to network input shape NCHW
input_image = np.expand_dims(np.transpose(resized_image, (2, 0, 1)), 0)
#print("3 - Load Image. Reshape image from {} to {}".format(image.shape, input_image.shape))
# Do inference
inference_start_time = time.perf_counter()
result = compiled_model([input_image])[output_key]
inference_stop_time = time.perf_counter()
inference_duration = inference_stop_time - inference_start_time
total_inference_duration += inference_duration
#print("4 - Model Inference. Inference result shape: {}".format(result.shape))
if input_video_frame_nr % (10 * ADVANCE_FRAMES) == 0:
clear_output(wait=True)
progress_bar.display()
# input_video_frame_nr // ADVANCE_FRAMES gives the number of
# frames that have been processed by the network
display(
Pretty(
f"Processed frame {input_video_frame_nr // ADVANCE_FRAMES}"
f"/{total_frames // ADVANCE_FRAMES}. "
f"Inference time per frame: {inference_duration:.2f} seconds "
f"({1/inference_duration:.2f} FPS)"
)
)
# Transform network result to RGB image
result_frame = to_rgb(convert_result_to_image(result))
# Resize image and result to target frame shape
result_frame = cv2.resize(result_frame, (target_frame_width, target_frame_height))
image = cv2.resize(image, (target_frame_width, target_frame_height))
# Put image and result side by side
stacked_frame = np.hstack((image, result_frame))
# Save frame to video
out_video.write(stacked_frame)
#print("5 - Write the results after post processing into video and finally saved in local path.")
input_video_frame_nr = input_video_frame_nr + ADVANCE_FRAMES
cap.set(1, input_video_frame_nr)
progress_bar.progress = input_video_frame_nr
progress_bar.update()
except KeyboardInterrupt:
print("Processing interrupted.")
finally:
clear_output()
processed_frames = num_frames // ADVANCE_FRAMES
out_video.release()
cap.release()
end_time = time.perf_counter()
duration = end_time - start_time
print(
f"Processed {processed_frames} frames in {duration:.2f} seconds. "
f"Total FPS (including video processing): {processed_frames/duration:.2f}."
f"Inference FPS: {processed_frames/total_inference_duration:.2f} "
)
print(f"Monodepth Video saved to '{str(result_video_path)}'.")
我们打开处理后的视频,会发现,MidasNet这个算法对于通用的大物体深度图效果普遍比较好,如下图
但是对于小物体,或者不是那么大众的视角拍过去的物体,深度图效果就很一般了,如下图: