前述:
本文为了记载自己在开发过程中的心得以及问题解决而写,阅读者在参考时尽量加入自己的思考,毕竟依赖库以及主程序随着时间的变化正在不断升级; 另外,牢骚一下,对于初学者,很多作者在写文章的时候,会漏掉诸多环节,对于他们而言,那是不言而喻的东西,对于初学者,那些是什么,怎么做?
参考地址:
https://tensorflow.google.cn/install/pip#windows
重新安装python:
python-3.7.9.exe是32位版本的python,在寻找对应的tensorflow时,无法找到,试了很多的方法均无法匹配,因此删除3.7.9版本,改为使用【python-3.7.9-amd64.exe】;同样的,openCV也需要安装64位版本; 安装文件如何寻找,可以参见博主之前的文章python+openCV安装(WINDOWS环境)
安装tensorflow:
在命令行中执行:pip3 install tensorflow 在下图中可以看到可以自动匹配到tensorflow的版本并且进行自动下载安装,但是由于网速有限,依然需要在https://pypi.org/上找到【tensorflow-2.3.0-cp37-cp37m-win_amd64.whl】文件下载后本地安装。
根据博主之前的文章python+openCV安装(WINDOWS环境)目前初始环境已经具备,即:Python 3.7.9;tensorflow 2.3.0;opencv 4.4.0
COCO数据集:
博主把下边这个表格拷贝过来,这个表格代表了每一种模型的速度、识别率等情况,表格中的超链接可以直接链接到下载地址,如果无法下载,可以拷贝超链接后,使用百度网盘等工具中的离线下载,亲测好用。 另外,其他模型暂时不考虑,主要是初学,贪多嚼不烂。
Model name | Speed (ms) | COCO mAP1 | Outputs |
30 | 21 | Boxes | |
26 | 18 | Boxes | |
29 | 18 | Boxes | |
29 | 16 | Boxes | |
26 | 20 | Boxes | |
56 | 32 | Boxes | |
76 | 35 | Boxes | |
31 | 22 | Boxes | |
29 | 22 | Boxes | |
27 | 22 | Boxes | |
42 | 24 | Boxes | |
58 | 28 | Boxes | |
89 | 30 | Boxes | |
64 | Boxes | ||
92 | 30 | Boxes | |
106 | 32 | Boxes | |
82 | Boxes | ||
620 | 37 | Boxes | |
241 | Boxes | ||
1833 | 43 | Boxes | |
540 | Boxes | ||
771 | 36 | Masks | |
79 | 25 | Masks | |
470 | 33 | Masks | |
343 | 29 | Masks |
这里以ssd_mobilenet_v2_coco_2018_03_29模型为例,
1、将其解压后,文件夹样式如图:
2、增加一个文件classes.txt,这个里面是将需要识别的类别名称列举到其中,博主把已知的所有名称都列举进去了(自己理解越少,识别起来越速度)。
3、书写python代码,并且调整某些参数:
########################################################
## DIC:tensorFlow auto recognition SSD Demo
## Base on ssd_mobilenet_v2_coco
## Depend On:Python 3.7.9;tensorflow 2.3.0;opencv 4.4.0
## Author:J.Y.Zhang 2020-09-11
## Ver:1.0
## See:## See:
########################################################
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import tensorflow as tf
# 加载coco数据集模型 TODO 需要改为自己文件路径
model_path = "F:/traindata/ssd_mobilenet_v2_coco_2018_03_29"
frozen_pb_file = os.path.join(model_path, 'frozen_inference_graph.pb')
# 加载coco数据集分类 TODO 需要改为自己文件路径
f = open("F:/traindata/ssd_mobilenet_v2_coco_2018_03_29/classes.txt", "r")
class_names = f.readlines()
# 此值决定了何种情况下判定为元素,取值越低认定的可能性越高,相应的图中的框就越多,识别为假的可能性越高
# TODO 如果需要则调整
score_threshold = 0.3
# 需要识别的图片的位置以及名称 TODO 需要改为自己文件路径
img_file = 'F:/traindata/labelImgdata/img/4.jpg'
# resize img for a format size TODO 可以修改成为自己想要的尺寸
RESIZE_IMG_WIDTH = 500
RESIZE_IMG_HEIGHT = 330
# Read the graph.
with tf.gfile.FastGFile(frozen_pb_file, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Session() as sess:
# Restore session
sess.graph.as_default()
tf.import_graph_def(graph_def, name='')
# Read and preprocess an image.
img_cv2 = cv2.imread(img_file)
img_height, img_width, _ = img_cv2.shape
# 对图片的长宽进行格式化,防止图片过大
img_in = cv2.resize(img_cv2, (RESIZE_IMG_WIDTH, RESIZE_IMG_HEIGHT))
img_in = img_in[:, :, [2, 1, 0]] # BGR2RGB
# Run the model
outputs = sess.run([sess.graph.get_tensor_by_name('num_detections:0'),
sess.graph.get_tensor_by_name('detection_scores:0'),
sess.graph.get_tensor_by_name('detection_boxes:0'),
sess.graph.get_tensor_by_name('detection_classes:0')],
feed_dict={'image_tensor:0': img_in.reshape(1, img_in.shape[0],
img_in.shape[1], 3)})
# Visualize detected bounding boxes.
num_detections = int(outputs[0][0])
for i in range(num_detections):
classId = int(outputs[3][0][i])
score = float(outputs[1][0][i])
bbox = [float(v) for v in outputs[2][0][i]]
if score > score_threshold:
x = bbox[1] * img_width
y = bbox[0] * img_height
right = bbox[3] * img_width
bottom = bbox[2] * img_height
# 标框
cv2.rectangle(img_cv2, (int(x), int(y)), (int(right), int(bottom)), (125, 255, 51), thickness=3)
# 文字"class_name, score"
cv2.putText(img_cv2, class_names[classId - 1][:-1] + "," + str("%.2f" % score), (int(x), int(y)),
cv2.FONT_HERSHEY_DUPLEX, .5, (0, 0, 255), 1)
print(str(classId) + ",class:" + class_names[classId - 1][:-1] + ",score:%.2f" % score)
# figsize : 指定figure的宽和高,单位为英寸
plt.figure(figsize=(10, 8))
plt.imshow(img_cv2[:, :, ::-1])
plt.title("TensorFlow MobileNetV2-SSD")
plt.axis("off")
plt.show()
4、执行代码:
我们可以尝试使用不同的图片、修改代码中的score_threshold值等方式查看执行情况。