Image Features
Like text features, image features are a kind of data that gradient boosting tree (GBDT) models find very hard to exploit on their own. Image problems such as classification and segmentation are nowadays dominated almost entirely by neural networks. In multimodal problems, however, such as product search and recommendation, where the data contain both image and text information, GBDT-based modeling remains essential, and to make full use of all the information we need to extract image features from multiple angles.
Following the ten features of the previous section, this section adds several more of the most classic image features.
1. Pretrained Image Features
There are many pretrained image models to choose from; typical ones include:
- VGG-16
- ResNet50
- Xception
- InceptionV3
- EfficientNet
- NFNet
- others; see the link
A pretrained model can be used off the shelf: we only need to convert the image into the input format the model expects and feed it in; typically we take the model's predictions, or the activations of the last few layers, as the image features. These are also the features used most often by winning teams in multimodal data competitions.
Note: because the final output of a pretrained network is often very high-dimensional, it is usually worth reducing its dimensionality; see the sketch after the example below.
- Extract ResNet50 activations as our image features.
import numpy as np
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input
from tensorflow.keras.preprocessing.image import load_img
from tensorflow.keras.preprocessing.image import img_to_array

# ImageNet-pretrained backbone without the classification head.
model = ResNet50(weights="imagenet", include_top=False)

img = load_img('./imgs/chapter7/img_example.jpeg', target_size=(224, 224))
img = img_to_array(img)              # PIL image -> (224, 224, 3) array
img = np.expand_dims(img, axis=0)    # add batch dimension -> (1, 224, 224, 3)
img = preprocess_input(img)          # ResNet50-specific normalization
res50_features = model.predict(img)  # feature map of the last conv block
res50_features.shape
(1, 7, 7, 2048)
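The (1, 7, 7, 2048) output above flattens to over 100,000 values per image, far too wide for a GBDT model; this is the dimensionality-reduction step the note above refers to. A minimal sketch, assuming the per-image feature maps have already been stacked into an array all_feats (that name and n_components=64 are illustrative choices, not from the original text):

import numpy as np
from sklearn.decomposition import PCA

# all_feats: (n_images, 7, 7, 2048), stacked model.predict outputs (assumed computed above)
pooled = all_feats.mean(axis=(1, 2))        # global average pooling -> (n_images, 2048)
pca = PCA(n_components=64)                  # 64 components is an arbitrary budget
gbdt_img_feats = pca.fit_transform(pooled)  # (n_images, 64) numeric columns for the GBDT model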
- Use the top-N predictions of InceptionV3 as image features.
import os
import cv2
import numpy as np
import pandas as pd
from PIL import Image
from tensorflow.keras.preprocessing import image
import tensorflow.keras.applications.resnet50 as resnet50
import tensorflow.keras.applications.xception as xception
import tensorflow.keras.applications.inception_v3 as inception_v3

# Instantiate the three pretrained classifiers (ImageNet weights, with classification head).
resnet_model = resnet50.ResNet50(weights='imagenet')
xception_model = xception.Xception(weights='imagenet')
inception_model = inception_v3.InceptionV3(weights='imagenet')

def image_classify(model, pak, img, top_n=3, target_size=(299, 299)):
    """Classify image and return top matches."""
    if img.size != target_size:
        img = img.resize(target_size)
    x = image.img_to_array(img)
    x = np.expand_dims(x, axis=0)
    x = pak.preprocess_input(x)
    preds = model.predict(x)
    return pak.decode_predictions(preds, top=top_n)[0]

def classify_and_plot(image_path):
    """Classify an image with different models and return its predictions."""
    img = Image.open(image_path)
    # ResNet50 expects 224x224 inputs; Xception and InceptionV3 expect 299x299.
    resnet_preds = image_classify(resnet_model, resnet50, img, target_size=(224, 224))
    xception_preds = image_classify(xception_model, xception, img)
    inception_preds = image_classify(inception_model, inception_v3, img)
    cv_img = cv2.imread(image_path)
    preds_arr = [('Resnet50', resnet_preds), ('Xception', xception_preds), ('Inception', inception_preds)]
    return (img, cv_img, preds_arr)
img = load_img('./imgs/chapter7/img_example.jpeg', target_size=(299, 299))  # InceptionV3 expects 299x299
inception_preds = image_classify(inception_model, inception_v3, img)
inception_preds
[('n03933933', 'pier', 0.9737361),
('n03216828', 'dock', 0.0070415554),
('n09332890', 'lakeside', 0.0041139866)]
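To feed these top-N predictions to a GBDT model, one option is to keep the class names as categorical columns and the probabilities as numeric columns. A minimal sketch built on the inception_preds result above (the column names are our own):

import pandas as pd

rows = []
for preds in [inception_preds]:  # in practice: one decode_predictions result per image
    row = {}
    for rank, (_, class_name, prob) in enumerate(preds):
        row['incep_class_{}'.format(rank)] = class_name  # categorical (label-encode for the GBDT)
        row['incep_prob_{}'.format(rank)] = float(prob)  # numeric
    rows.append(row)
feat_df = pd.DataFrame(rows)  # one row per image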
2. SIFT Features
SIFT (Scale-Invariant Feature Transform) is an algorithm for detecting and describing local features in digital images. Its descriptors remain stable under various transformations (even though the same region may look quite different after a transformation), and it was the most popular local-feature algorithm before deep learning. SIFT keypoints are convenient and fairly fast to extract, and robust to scaling and similar image transformations.
import cv2
import matplotlib.pyplot as plt

sift = cv2.SIFT_create()
img = cv2.imread('./imgs/chapter7/img_example.jpeg')
kp, des = sift.detectAndCompute(img, None)  # keypoints and 128-dim descriptors
img_kp = cv2.drawKeypoints(img, kp, None)   # draw keypoints on a copy, leaving img untouched
plt.figure(figsize=(10, 10))
plt.imshow(img_kp);
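The number of keypoints varies per image, so the descriptors must be collapsed into fixed-length columns before a GBDT model can use them. A minimal sketch using simple aggregate statistics over the des and kp computed above (the choice of statistics is an illustrative assumption):

import numpy as np

def sift_stats(des, kp):
    """Collapse variable-length SIFT output into a fixed-length feature vector."""
    if des is None or len(kp) == 0:          # images with no detected keypoints
        return np.zeros(1 + 128 + 128)
    return np.concatenate([[len(kp)],        # keypoint count
                           des.mean(axis=0), # per-dimension mean of the 128-dim descriptors
                           des.std(axis=0)]) # per-dimension std

sift_feats = sift_stats(des, kp)  # 257 numeric columns per image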
3. SURF Features
SIFT features are excellent, but their computation is fairly slow. To speed it up, Bay, H., Tuytelaars, T., and Van Gool, L. proposed a new algorithm called SURF (Speeded-Up Robust Features). As the name suggests, it is an accelerated version of SIFT.
Analyses show that SURF is about three times faster than SIFT with comparable performance. SURF handles blurred and rotated images well, but copes poorly with viewpoint and illumination changes.
### Due to patent issues, some OpenCV builds no longer ship SURF; to use it,
### you need a version built with the nonfree/contrib modules.
### SURF implementations are no longer included in the OpenCV 3 library by default.
surf = cv2.xfeatures2d.SURF_create(400)  # Hessian threshold 400
# Find keypoints and descriptors directly
kp, des = surf.detectAndCompute(img, None)
4. ORB Features
ORB is a good alternative to SIFT and SURF in terms of both computation cost and matching performance. SIFT and SURF are patented and in principle require licensing fees (SIFT's patent has since expired, which is why cv2.SIFT_create is back in mainline OpenCV), whereas ORB is free. ORB fuses the FAST keypoint detector with the BRIEF descriptor, with many modifications to improve performance: it first finds keypoints with FAST, then uses the Harris corner measure to keep the top N of them, and it also builds an image pyramid to produce multi-scale features.
orb = cv2.ORB_create()  # OpenCV 3 backward incompatibility: do not create a detector with cv2.ORB().
key_points, description = orb.detectAndCompute(img, None)
img_building_keypoints = cv2.drawKeypoints(img, key_points, None,
                                           flags=cv2.DRAW_MATCHES_FLAGS_DRAW_RICH_KEYPOINTS)  # circles show scale and orientation
plt.figure(figsize=(10, 10))
plt.title('ORB Interest Points')
plt.imshow(img_building_keypoints);
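Another classic way to turn local descriptors such as ORB's into tabular features is a bag-of-visual-words: cluster descriptors pooled from the whole image set, then represent each image as a histogram over the clusters. A minimal sketch, assuming all_descriptions is a list with one ORB descriptor array per image (that name and the vocabulary size of 64 are illustrative; running Euclidean k-means on binary ORB descriptors is a rough but common shortcut):

import numpy as np
from sklearn.cluster import MiniBatchKMeans

# Fit the visual vocabulary on descriptors pooled across all images.
vocab = MiniBatchKMeans(n_clusters=64, random_state=0)
vocab.fit(np.vstack([d for d in all_descriptions if d is not None]).astype(np.float32))

def bovw_histogram(des, vocab, n_clusters=64):
    """Normalized histogram of visual-word assignments for one image's descriptors."""
    if des is None:
        return np.zeros(n_clusters)
    words = vocab.predict(des.astype(np.float32))
    return np.bincount(words, minlength=n_clusters) / len(words)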
5. FAST Features
Features extracted by algorithms such as SIFT and SURF are of high quality (strongly invariant) but computationally expensive, which may be unacceptable in practice. Edward Rosten and Tom Drummond proposed the FAST keypoint detector in their 2006 paper "Machine learning for high-speed corner detection" and published a slightly revised version in 2010, "Features from Accelerated Segment Test", hence the name FAST.
path = './imgs/chapter7/img_example.jpeg'
img = cv2.imread(path)
fast = cv2.FastFeatureDetector_create(40)  # detection threshold = 40
# find and draw the keypoints
kp = fast.detect(img, None)
img2 = cv2.drawKeypoints(img, kp, None, color=(255, 0, 0))
# Print all default params
print("Threshold: {}".format(fast.getThreshold()))
print("nonmaxSuppression: {}".format(fast.getNonmaxSuppression()))
print("neighborhood: {}".format(fast.getType()))
print("Total Keypoints with nonmaxSuppression: {}".format(len(kp)))
Threshold: 40
nonmaxSuppression: True
neighborhood: 2
Total Keypoints with nonmaxSuppression: 1483
plt.figure(figsize=(10, 10))
plt.title('FAST Interest Points')
plt.imshow(img2);
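FAST only detects keypoints and produces no descriptors, so as a hedged example one can use the keypoint count at a few thresholds as coarse corner-density features (the thresholds below are arbitrary choices):

fast_counts = []
for thresh in (20, 40, 80):  # arbitrary thresholds; lower thresholds find more corners
    detector = cv2.FastFeatureDetector_create(thresh)
    fast_counts.append(len(detector.detect(img, None)))
# fast_counts gives three numeric columns describing how "corner-rich" the image is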
6. BEBLID Features
One of the most exciting additions in OpenCV 4.5.1 is BEBLID (Boosted Efficient Binary Local Image Descriptor), a descriptor introduced in 2020 that improves image-matching accuracy while reducing execution time, and that has been shown to improve on ORB across several tasks. Because BEBLID works with several detection methods, its scale parameter must be matched to the detector; for ORB keypoints it should be set to 0.75~1.
In the comparison experiments of "Improving your image matching results by 14% with one line of code", using the BEBLID descriptor yielded 77.57% inliers; commenting out BEBLID in the description cell and using the ORB descriptor instead dropped this to 63.20%:
import cv2

# Comment or uncomment to use ORB or BEBLID as the descriptor
path = './imgs/chapter7/img_example.jpeg'
img = cv2.imread(path)
detector = cv2.ORB_create(10000)  # detect up to 10k ORB keypoints
kpts1 = detector.detect(img, None)
# BEBLID lives in opencv-contrib; 0.75 is the recommended scale for ORB keypoints
descriptor = cv2.xfeatures2d.BEBLID_create(0.75)
kpts1, desc1 = descriptor.compute(img, kpts1)
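As a hedged illustration of how inlier-style percentages like those above can be measured, the BEBLID descriptors of two images can be matched with a brute-force Hamming matcher plus Lowe's ratio test (the second image below reuses the Apple.png path from section 7; the 0.8 ratio is an illustrative choice):

# Descriptors for a second image to match against.
img_b = cv2.imread('./imgs/chapter7/Apple.png')
kpts2 = detector.detect(img_b, None)
kpts2, desc2 = descriptor.compute(img_b, kpts2)

matcher = cv2.BFMatcher(cv2.NORM_HAMMING)  # BEBLID produces binary descriptors
matches = matcher.knnMatch(desc1, desc2, k=2)
good = [pair[0] for pair in matches
        if len(pair) == 2 and pair[0].distance < 0.8 * pair[1].distance]  # Lowe's ratio test
match_ratio = len(good) / max(len(matches), 1)  # crude matching-quality feature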
7. Image Aggregation Features
Scan all the images, match them pairwise, and by some custom rule take the ID of the image with the most matches as a new feature, i.e., use matching to find each image's nearest neighbors (a sketch of this step follows the matching example below).
Below we use SIFT as an example; SURF, ORB, FAST, BEBLID and the other descriptors above can equally be used to find nearest-neighbor information.
import numpy as np
import cv2
from matplotlib import pyplot as plt
'''
https://opencv-python-tutroals.readthedocs.io/en/latest/py_tutorials/py_feature2d/py_matcher/py_matcher.html#matcher
'''
img1 = cv2.imread('./imgs/chapter7/img_example.jpeg', 0)  # queryImage (grayscale)
img2 = cv2.imread('./imgs/chapter7/Apple.png', 0)         # trainImage (grayscale)

# Initiate SIFT detector
sift = cv2.SIFT_create()

# find the keypoints and descriptors with SIFT
kp1, des1 = sift.detectAndCompute(img1, None)
kp2, des2 = sift.detectAndCompute(img2, None)

# FLANN parameters
FLANN_INDEX_KDTREE = 1
index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
search_params = dict(checks=50)  # or pass empty dictionary
flann = cv2.FlannBasedMatcher(index_params, search_params)
matches = flann.knnMatch(des1, des2, k=2)

# Need to draw only good matches, so create a mask
matchesMask = [[0, 0] for i in range(len(matches))]

# ratio test as per Lowe's paper
for i, (m, n) in enumerate(matches):
    if m.distance < 0.7 * n.distance:
        matchesMask[i] = [1, 0]

draw_params = dict(matchColor=(0, 255, 0),
                   singlePointColor=(255, 0, 0),
                   matchesMask=matchesMask,
                   flags=0)
img3 = cv2.drawMatchesKnn(img1, kp1, img2, kp2, matches, None, **draw_params)
plt.figure(figsize=(10, 10))
plt.imshow(img3);
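Putting it together, the aggregation feature described at the start of this section can be sketched as follows: for each query image, count the good matches against every candidate image and keep the ID of the best-matching one. This is a minimal sketch; all_des and the two helpers are our own names, assuming one SIFT descriptor array per image has been computed as above:

def count_good_matches(des_a, des_b, flann):
    """Number of Lowe-ratio matches between two images' SIFT descriptors."""
    if des_a is None or des_b is None:
        return 0
    matches = flann.knnMatch(des_a, des_b, k=2)
    return sum(1 for pair in matches
               if len(pair) == 2 and pair[0].distance < 0.7 * pair[1].distance)

def nearest_image_id(query_des, all_des, flann):
    """ID (index) of the candidate image with the most good matches."""
    scores = [count_good_matches(query_des, des, flann) for des in all_des]
    return int(np.argmax(scores))

# all_des: list of SIFT descriptor arrays, one per image in the dataset
# neighbor_id = nearest_image_id(des1, all_des, flann)  # new categorical feature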