论文链接:AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE
论文代码:https://github.com/google-research/vision_transformer
目录
1、Abstract and background
2、method
2.1、VISION TRANSFORMER (VIT)
2.2、FINE-TUNING AND HIGHER RESOLUTION
3、EXPERIMENTS
3.1、Setup
4、COMPARISON TO STATE OF THE ART(SOTA)
5、Pre-train data requirement
6、SCALING STUDY
7、INSPECTING VISION TRANSFORMER
8、SELF-SUPERVISION
9、conclusion
1、Abstract and background
图像分成块序列然后输入到Transformer中执行图像分类任务。
将一幅图像分割为多个patch(图像块),并将这些patch的线性嵌入序列作为Transformer的输入。图像块与NLP中的token(单词)的处理方式相同。然后采用有监督的方式对模型进行图像分类训练。
图1、总体流程图
linear projection)每个块,并添加位置嵌入(position embedding),并将生成的矢量序列输入到Transformer编码器中。为了执行分类,向序列中添加额外的可学习“分类标记”。这个视觉transformer encoder的灵感来源文献:Attention is all you need。
文中提到在中等训练数据集中,如果没有采用强正则化,它的效果不如ResNet。
但是如果在大规模的数据集中训练,就不会出现这种情况。并且将训练好的Vit模型进行迁移训练也可以取得很好的效果。
在ImageNet-21k数据集或in-house JFT-300M数据集上进行预训练时,ViT在多个图像识别基准上接近或超过了最新水平。最佳模型在ImageNet上的精度达到88.55%,在ImageNet ReaL上的精度达到90.72%,在CIFAR-100上的精度达到94.55%,在VTAB等19项任务上的精度达到77.63%。
主要原因:Transformer 缺乏 CNN 固有的一些归纳偏差,例如:translation equivariance 和 locality。所以当训练数据不足的情况下,效果不好。
什么是translation equivariance?(平移不变性)
具体可以参考一下面的动态图,左边是一张数字图像,随着左边图像中数字4的位置发生一定量的偏移,右边特征图的位置也发生一定量的偏移。
图2、平移不变性
如果还不能理解,可以再看一下面这张图像:
输入图像 X1,显示数字“4”,向右平移,得到输入图像 X2。F1 和 F2 分别是通过平移等变映射计算得到的特征图。在这种情况下,通过将 X2 传递给 φ 获得的特征图 F2 等效于通过将相同的平移 T 应用于特征图 F1 获得的特征图,该平移 T 也已应用于 X1 以获得 X2。以上这些都是CNN所特有的。主要在于CNN中的卷积是唯一的线性和平移不变的算子。虽然卷积是平移等变的而不是不变的,但通过将卷积与空间池化算子相结合,可以在神经网络中实现近似的平移不变性。
2、method
ViT(视觉transformer)尽可能遵循最开始的NLP中的transformer中的结构,最主要的优点就是可以开箱使用,比较方便。
2.1、VISION TRANSFORMER (VIT)
完整的ViT结构如图1总流程图,原始的transformer是采用1D的token序列作为输入。
为了能够处理2D图像,将图像 X :格式 {H×W×C} 调整为2D的图像块Xp:{N×(P^2*C)},其中(H,W)是原始图像的分辨率,C是通道数,(P,P)是每个图像块的分辨率,N=H*W/p^2是产生的图像块数量,N也是Transformer的有效输入序列长度。
Transformer所有层中使用恒定的潜在向量大小D,然后将图像块(patches)展平,使用可训练的线性投影将其映射到D维(如下图公式(1)所示),然后再将投影的输出称为:patch embedding
和BERT的[class] token类似,为Embedding patch序列(
)预先准备了一个可学习的嵌入,其在Transformer编码器(
)输出端的状态用作图像表示y(公式(4))。在 pre-train 和 fine-tuning 期间,将一个分类头(head)连接到
。分类头在预训练时由一个带有一个隐藏层的MLP实现,在微调时由一个线性层实现。
位置嵌入(position embedding)添加到 patch embedding 以保留位置信息。并且使用标准的可学习一维位置嵌入,因为使用更先进的二维感知位置嵌入带来的显著性能提升(附录D.4)。生成的嵌入向量序列作为编码器的输入。
公式(2)、(3))层组成。在每个块之前应用LayerNorm(LN),在每个块之后应用残差连接。MLP包含两个具有GELU的非线性层。
Inductive bias:transformer比CNN具有更少的图像特异性归纳偏置。在CNN中,局部性、二维邻域结构和平移不变性被输入到整个模型的每一层中。在ViT中,只有MLP层是局部的和平移不变的,而自注意力层(self-attention)是全局的。二维邻域结构的使用非常少:在模型开始时,通过将图像切割成小块,并在微调时调整不同分辨率图像的位置嵌入(如下所述)。除此之外,初始化时的位置嵌入没有关于图像块的二维位置信息,并且图像块之间的所有空间关系都必须从头开始学习。
Hybrid Architecture:作为原始图像块的替代,输入序列可以由CNN的特征图生成。在这个混合模型中,将图像块嵌入投影E(公式(1))应用于从CNN特征图中提取的patches。patches可以具有空间大小1x1,这意味着通过简单地展平特征图的空间维度并投影到Transformer维度来获得输入序列。如上所述,添加了分类输入嵌入和位置嵌入。
2.2、FINE-TUNING AND HIGHER RESOLUTION
在大型数据集上预先训练ViT,并对(较小的)下游任务进行微调。为此,我们移除预先训练好的预测头,并附加一个初始化为零的D×K前馈层,其中K是下游类的数量。以比训练前更高的分辨率进行微调通常是有益的。当输入更高分辨率的图像时,保持patch大小不变,这会导致更大的有效序列长度。Vision Transformer可以处理任意序列长度(最多可达内存限制),但是,预先训练的位置嵌入可能不再有意义。因此,根据预训练位置嵌入在原始图像中的位置,对其执行2D插值。这种分辨率调整和patch提取是将图像2D结构的归纳偏差手动输入到ViT中。
3、EXPERIMENTS
文中评估了ResNet、Vision Transformer(ViT)和hybrid的表征学习能力。为了理解每个模型的数据需求,对不同大小的数据集进行预训练,并评估许多基准任务。当考虑预训练模型的计算成本时,ViT表现非常好,以较低的预训练成本在大多数识别基准上达到了最先进的水平。最后,使用自监督进行了一个小实验,实验证明自监督的ViT很有前景。
3.1、Setup
DataSet:为了探索模型的可扩展性,本文使用ILSVRC-2012 ImageNet数据集,该数据集包含1k类和130万幅图像(下文中称为ImageNet)。并对训练数据集和下游任务的测试集进行去重操作。再将在这些数据集上训练的模型转移到几个基准任务中:原始验证标签和清理后的真实标签上的ImageNet、CIFAR-10/100、Oxford IIIT Pets和Oxford Flowers-102。对于这些数据集都采用了相应的预处理。
评估了19项任务VTAB分类。VTAB评估不同任务的低数据传输,每个任务使用1000个训练样例。这些任务分为三组:自然任务——如上述的基准任务、Pets、CIFAR等。专业任务——医学和卫星图像,以及结构化任务——需要几何理解的任务,如定位。
这里的VTAB(Visual Task Adaptation Benchmark)是谷歌推出的“视觉领域任务自适应基准”:
VTAB 方案首先将一种算法 (A) 应用到大量独立的常见视觉任务中。该算法可以利用上游数据进行预训练,以生成包含视觉表征的模型。但其必须同时定义适应性策略,使用少量样本对下游任务进行训练,执行特定任务并返回预测模型。该算法的最终分数是它在各任务中的平均测试分数。
具体可以参考下图:
Model Variants:本文将ViT配置与BERT的配置保持一致,如Table 1所示。“Base”和“Large”模型直接采用了BERT模型,本文添加了更大的“Huge”模型。下文使用简短的符号来表示模型大小和输入patch大小:例如,ViT-L/16表示具有16×16输入patch大小的“Large”变体。其中transformer的序列长度与patch大小的平方成反比,因此patch越小模型的计算成本就越高。
本文采用ResNet作为基线CNN,但将批量归一化层(batch normalization)替换为组归一化层(group normalization),并使用标准化卷积。这些修改改善了Transfer,将修改后的模型称为“ResNet(BiT)”。
对于Hybirds,将中间特征图以一个“像素”的patch大小提供给ViT。为了对不同的序列长度进行实验,要么(i)获取常规ResNet50的第4阶段的输出,要么(ii)移除第4阶段,在第3阶段放置相同数量的层(保持总层数),然后采用扩展的第3阶段的输出。选项(ii)导致序列长度延长4倍,ViT模型更费时。
Training & Fine-tuning:采用β1=0.9、β2=0.999、batchsize为4096的Adam对所有模型(包括Resnet)进行训练,并应用0.1的high weight decay,这对所有模型的传输非常有用(Adam对Resnet的效果略好于SGD)。
使用线性学习率warmup和decay。对于微调,使用带有momentum的SGD,batchsize为512,对于所有模型,请参见附录B.1.1,如下图所示。
table 2中的ImageNet结果,采用更高的分辨率进行了微调:ViT-L/16为512,ViT-H/14为518,并且使用Polyak&Juditsky(1992:Acceleration of Stochastic Approximation by Averaging)的平均值,系数为0.9999。
Metrics:本文展示了Few-shot和fine-tune在下游数据集的精度。微调精度在各自数据集上微调后得到提升。Few-shot精度通过解决一个正则化最小二乘回归问题,将训练图像子集的(冻结)表示映射到{−1,1}^K目标向量。这个公式允许我们以封闭形式恢复精确解。虽然主要关注微调性能,但有时会使用线性Few-shot精度进行快速动态评估,因为微调成本太高。
4、COMPARISON TO STATE OF THE ART(SOTA)
本文将大的模型ViT-H/14和ViT-L/16与最先进的CNN进行比较。
table 2所示:
从table 2中可以看出,在JFT-300M上预训练的小模型ViT-L/16在所有任务上的效果都优于BiT-L(ResNet152×4)(在同一个数据集上训练的),同时训练所需的计算资源也大大减少。更大的模型ViT-H/14进一步提高了性能,尤其是在更具挑战性的数据集上—ImageNet、CIFAR-100和VTAB。同时该模型的预训练所需的计算量仍然大大减少。Pre-train的效率不仅受到模型架构选择的影响,还受到其他参数的影响,例如training schedule、优化器、权重衰减等。
在公共ImageNet-21k数据集上预训练的ViT-L/16模型在大多数数据集上也表现良好,同时预训练所需的资源较少:它可以在大约30天内使用8个核心的标准云TPUv3进行训练。
Figure 2将VTAB任务分解为各自的组,并与该基准上以前的SOTA方法进行比较:BiT、VIVI(在ImageNet和Youtube上共同培训的ResNet),以及S4L(在ImageNet上的监督加半监督学习)。在自然任务和结构化任务上,ViT-H/14的性能优于BiT-R152x4和其他方法。在specialized上,前两个模型的性能相似。
5、Pre-train data requirement
Vision Transformer在对大型JFT-300M数据集进行预训练时表现良好。与resnet相比,视觉的归纳偏差更少,那么数据集的大小有多重要?文中进行了两个系列的实验。
首先,在大的数据集上预训练ViT模型:ImageNet、ImageNet-21k和JFT300M。为了在较小的数据集上提高性能,选择了三个基本的正则化参数—权重衰减、dropout和标签平滑。
Figure 3显示了微调到ImageNet后的结果(其他数据集的结果如Table 5所示,微调过程中分辨率的提高提高了性能)。当在最小的数据集ImageNet上进行预训练时,ViT大型模型的性能不如ViT基础模型,尽管(适度)正则化。在ImageNet-21k预训练中,他们的表现相似。只有使用JFT-300M,才能看到large models的全部好处。
Figure 3同时展示了不同大小BiT模型跨越的区域的性能表现,BiT CNN在ImageNet上的性能优于ViT,但在更大的数据集上,ViT表现最好。
其次,在9M、30M和90M的随机子集以及完整的JFT300M数据集上训练ViT模型。本文不会对较小的子集执行额外的正则化,并对所有设置使用相同的超参数。通过这种方式评估内在的模型属性,而不是正则化的效果。使用early stopping,并在训练期间打印达到的最佳验证的准确率。为了节省计算量,采用few-shot的线性精度,而不是全微调精度。
Figure 4展示了结果。在较小的数据集上,Vision Transformers的计算成本比Resnet更高。例如,ViT-B/32比ResNet50稍快;它在9M子集上的性能要差得多,但在90M+子集上的性能更好。ResNet152x2和ViT-L/16也是如此。这一结果强化了一种直觉,即卷积归纳偏差对于较小的数据集是有用的,但对于较大的数据集,直接从数据中学习相关模式就足够了,甚至是有益的。
Figure 4)以及VTAB上的低数据结果(Table 2)似乎有望实现非常低的数据传输。进一步分析ViT的few-shot特性是未来工作的一个令人兴奋的方向。
6、SCALING STUDY
通过评估JFT-300M的Transfer性能,对不同的模型进行scale研究。在这种情况下,数据大小不会限制模型的性能,我们会评估每个模型的性能和预训练成本。
该模型集包括:7个RESNET,R50x1、R50x2、R101x1、R152x1、R152x2,预训练7个epoch,加上R152x2和R200x3预训练14个epoch;6个transformer,ViT-B/32、B/16、L/32、L/16,预训练7个epoch,加上预训练14个epoch的L/16和H/14;5个hybirds,R50+ViT-B/32、B/16、L/32、L/16,预训练7个epcoh,再加上R50+ViT-L/16,预训练14个epoch(对于hybirds,模型名称末尾的数字代表的不是patch的大小,而是ResNet主干中的总采样率)。
图5显示了Transfer性能与总pre-training计算结果的对比。针对不同体系结构的性能与pre-training计算:transformer、resnet和hybird。在计算预算相同的情况下,Transformer的性能通常优于resnet。对于较小的patch尺寸,hybird优于transformer,但对于较大的patch,差距消失了。
Figure 5
每个模型的详细结果如表6所示。首先,transformer在性能/计算权衡方面占据主导地位。ViT使用大约2个− 4×更少的计算以达到相同的性能(平均超过5个数据集)。其次,Hybird在较小的计算预算下略优于ViT,但在较大的模型中,这种差异消失了。这个结果有点令人惊讶,因为人们可能期望卷积局部特征处理在任何大小的ViT中都能起到辅助作用。第三,transformer似乎不会在尝试的范围内饱和,从而推动未来的扩展工作。
7、INSPECTING VISION TRANSFORMER
为了开始理解transformer如何处理图像数据,本文分析了它的内部表示。transformer的第一层线性地将平坦的patch投影到较低维度的空间中(公式1)。图7(左)显示了学习的嵌入滤波器的顶部主要组件。这些成分类似于每个patch内精细结构的低维表示的合理基函数。
投影后,将学习的位置嵌入添加到patch表示中。图7(中间)展示了该模型学习在位置嵌入的相似性中对图像内的距离进行编码,即越近的patch往往具有更相似的位置嵌入。此外,行-列结构出现;同一行/列中的patch具有类似的嵌入。最后,对于较大的网格,正弦结构有时很明显(附录D)。位置嵌入学习表示2D图像拓扑,这解释了为什么手工制作的2D感知嵌入变体没有提高(附录D.4)。
自注意力使ViT能够在整个图像中整合信息,即使是在最底层。我们调查网络在多大程度上利用了这种能力。具体来说,根据注意力权重计算图像空间中整合信息的平均距离(图7,右图)。这种“注意距离”类似于CNN中的感受野大小。一些头部关注的是已经位于最底层的大部分图像,这表明该模型确实使用了全局集成信息的能力。其他注意头在低层的注意距离一直很小。这种高度局部化的关注在Tranformer之前应用ResNet的hybird模型中不太明显(图7,右图),这表明它可能与CNN中的早期卷积层具有类似的功能。此外,注意距离随着网络深度的增加而增加。在全球范围内,我们发现该模型关注与分类语义相关的图像区域(图6)。
8、SELF-SUPERVISION
Transformers在NLP任务中表现出色。然而,他们的成功不仅源于出色的可扩展性,还源于大规模的自我监督预训练。本文还模拟了BERT中使用的masked patch 预测模型,对用于自我监督的masked patch预测进行了初步探索。通过自监督预训练,对于较小的ViT-B/16模型在ImageNet上实现了79.9%的准确率,与从头开始的训练相比,显著提高了2%,但仍落后于监督预训练4%。附录B.1.2包含更多详细信息。
9、conclusion
Transformer在图像识别中的直接应用与以前在计算机视觉中使用自我注意的工作不同,除了最初的patch提取步骤外,没有在体系结构中引入特定于图像的感应偏差。相反,将图像解释为一系列patch,并使用NLP中使用的标准Transformer编码器对其进行处理。这种简单但可扩展的策略在与大型数据集的预训练相结合取得很好的效果。因此,Vision Transformer在许多图像分类数据集上都达到或超过了最先进的水平,同时预训练成本相对较低。
虽然这些初步结果令人鼓舞,但仍存在许多挑战。一种是将ViT应用于其他计算机视觉任务,如检测和分割。另一个挑战是继续探索自监督的预训练方法。初步实验表明,自监督预训练效果有所改善,但自监督预训练与大规模监督预训练之间仍存在较大差距。最后,进一步扩展ViT可能会提高性能。
10、代码部分
算法总体流程图
代码总流程:代码中有注释!
注意:实际代码并不是按照下面这个顺序来的!
1、导入所需的包
import os
import math
import numpy as np
import pickle as p
import tensorflow as tf
import pandas as pd
from tensorflow import keras
import matplotlib.pyplot as plt
from tensorflow.keras import layers
import tensorflow_addons as tfa
%matplotlib inline
# 使用GPU
from tensorflow.compat.v1.keras import backend as K
config = tf.compat.v1.ConfigProto()
config.gpu_options.per_process_gpu_memory_fraction = 0.7
config.gpu_options.visible_device_list = "0,1"
sess = tf.compat.v1.Session(config=config)
K.set_session(sess)
2、打印tensorflow的版本
print("tensorflow版本:",tf.__version__)
print("tensorflow地址:",tf.__path__)
3、加载训练数据集:返回images和对应的labels
这里是cifar10数据集,同样可以换成其他数据集。
参考博客:
from_tensor_sliceshttps://zhuanlan.zhihu.com/p/380141130
代码如下:
def load_CIFAR_batch(filename):
"""
filename:对应cifar10中的5个batch
return:获取对应的images和对应的labels
"""
with open(filename,'rb') as f:
# 一个样本由标签和图像数据组成
# (3072=32x32x3)
data_dict = p.load(f,encoding='bytes')
# images和对应的labels
images= data_dict[b'data'] #10000*3072
labels = data_dict[b'labels'] #10000*1
# 把原始数据结构调整为: BCWH
images = images.reshape(10000,3,32,32)
# tensorflow处理图像数据的结构:BWHC
# 调整数据的维度,把通道数据C移动到最后一个维度
images = images.transpose(0,2,3,1)
# 将list转换成数组形式
labels = np.array(labels)
return images,labels
def load_CIFAR_data(data_dir):
"""
data_dir:数据地址
return:返回完整的数据和对应的标签
"""
images_train=[]
labels_train=[]
for i in range(5):
# 5个batch数据
f = os.path.join(data_dir,'data_batch_%d'%(i+1))
print('loading ',f)
# 调用load_CIFAR_batch()获得批量的图像及其对应的标签
image_batch,label_batch = load_CIFAR_batch(f)
images_train.append(image_batch)
labels_train.append(label_batch)
# 将所有batch合并
Xtrain = np.concatenate(images_train)
Ytrain = np.concatenate(labels_train)
# 删除操作
del image_batch ,label_batch
Xtest,Ytest = load_CIFAR_batch(os.path.join(data_dir,'test_batch'))
print('finished loadding CIFAR-10 data')
# 返回训练集的图像和标签,测试集的图像和标签
return (Xtrain,Ytrain),(Xtest,Ytest)
# 加载数据和标签
data_dir = r'/content/transformer_classification/data/cifar-10-batches-py'
(x_train,y_train),(x_test,y_test) = load_CIFAR_data(data_dir)
# 将numpy数据的格式转换为tf所需的dataset格式
train_dataset = tf.data.Dataset.from_tensor_slices((x_train,y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test,y_test))
4、可视化训练数据集
# 显示训练集中的图像
import matplotlib.pyplot as plt
plt.imshow(x_train[8])
import matplotlib.pyplot as plt
plt.figure(figsize=(4, 4))
image = x_train[np.random.choice(range(x_train.shape[0]))] # 32*32*3
plt.imshow(image.astype("uint8"))
plt.title("origin image")
plt.axis("off")
# resize
resized_image = tf.image.resize(
tf.convert_to_tensor([image]), size=(image_size, image_size) #72*72
)
print(resized_image.shape)#(1,72,72,3)
patches = Patches(patch_size)(resized_image)
print(f"Image size: {image_size} X {image_size}")# 图像大小
print(f"Patch size: {patch_size} X {patch_size}")# patch块的大小
print(f"Patches per image: {patches.shape[1]}") # 一张图总共有多少patches,72/6**2=12*12=144
print(f"Elements per patch: {patches.shape[-1]}") # 108一个patch包含的像素个数
print(f"patches shapes:{patches.shape}") # (1,144,108)
n = int(np.sqrt(patches.shape[1]))
print(n) # 12=72/6
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches[0]):
ax = plt.subplot(n, n, i + 1)
patch_img = tf.reshape(patch, (patch_size, patch_size, 3))
plt.imshow(patch_img.numpy().astype("uint8"))
plt.axis("off")
5、数据增强模块
data_augmentation = keras.Sequential(
[
layers.experimental.preprocessing.Normalization(), #归一化
layers.experimental.preprocessing.Resizing(image_size, image_size),#调整图像尺寸
layers.experimental.preprocessing.RandomFlip("horizontal"),#随机水平翻转
layers.experimental.preprocessing.RandomRotation(factor=0.02),#旋转
layers.experimental.preprocessing.RandomZoom(height_factor=0.2,width_factor=0.2)#随机调整图像大小
],
name="data_augmentation",
)
print(data_augmentation)
# 使预处理层的状态与正在传递的数据相匹配
# 计算训练数据集的均值和方差
print(data_augmentation.layers[0].adapt(x_train))
6、模型搭建
num_classes = 10
input_shape = (32, 32, 3)
learning_rate = 0.001
weight_decay = 0.0001
batch_size = 128
num_epochs = 10
image_size = 72 # 调整的图像大小
patch_size = 6 # transformer中的图像patches
num_patches = (image_size // patch_size) ** 2 # 一张图根据patches的大小计算总的个数
projection_dim = 64 # projection
num_heads = 4 # head
transformer_units = [
projection_dim * 2,
projection_dim,
] # Size of the transformer layers
transformer_layers = 8
mlp_head_units = [2048, 1024] # Size of the dense layers of the final classifier
def mlp(x, hidden_units, dropout_rate):
# 多层感知机
"""
hidden_unints:隐藏层个数
dropout_rate:设置对应的个数
"""
for units in hidden_units:
# mlp_head_units
x = layers.Dense(units, activation=tf.nn.gelu)(x)
x = layers.Dropout(dropout_rate)(x)
return x
# patchs类别
class Patches(layers.Layer):
def __init__(self, patch_size):
# patch_size:图像块的大小,不同的设置会得到不同的图像块数量
super(Patches, self).__init__()
self.patch_size = patch_size
def call(self, images):
# 图像的W
batch_size = tf.shape(images)[0]
# 提取图像并划分为多个图像子块
# 参考博客:
# 返回的是:返回4维tensor,数据类型与输入的images类型相同
patches = tf.image.extract_patches(
images=images,# 这个必须满足[batch, in_rows, in_cols, depth]
sizes=[1, self.patch_size, self.patch_size, 1],# patch的大小,[1, size_rows, size_cols, 1]
strides=[1, self.patch_size, self.patch_size, 1],# patch的移动步长[1, stride_rows, stride_cols, 1]
rates=[1, 1, 1, 1],# [1, rate_rows, rate_cols, 1],表示隔几个像素点,取一个像素点,直到满足sizes
padding="VALID",# 表示所取的patch区域必须完全包含在原始图像中.还有一种是"same"是对超出的图像部分通过补0实现
)
patch_dims = patches.shape[-1] # 返回最后一个维度
patches = tf.reshape(patches, [batch_size, -1, patch_dims])
return patches
class PatchEncoder(layers.Layer):
def __init__(self, num_patches, projection_dim):
super(PatchEncoder, self).__init__()
self.num_patches = num_patches
#一个全连接层,其输出维度为projection_dim,没有指明激活函数
self.projection = layers.Dense(units=projection_dim)
#定义一个嵌入层,这是一个可学习的层
#输入维度为num_patches,输出维度为projection_dim
self.position_embedding = layers.Embedding(
input_dim=num_patches, output_dim=projection_dim
)
def call(self, patch):
# start->limit(不包括这个),增量为1
positions = tf.range(start=0, limit=self.num_patches, delta=1)
# patch + position_embedding
encoded = self.projection(patch) + self.position_embedding(positions)
return encoded
7、训练
def create_vit_classifier():
inputs = layers.Input(shape=input_shape)
# 数据增强
augmented = data_augmentation(inputs)
# 创建图像的patches.
patches = Patches(patch_size)(augmented)
# 对上述创建的patches进行编码
encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)
# Create multiple layers of the Transformer block.
# 创建transformer块的多个层
# _表示一个占位符,表示不在意使用这个值,只是用于循环
for _ in range(transformer_layers):
# 对应总流程中的Norm
x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
# 创建一个多头注意力层
attention_output = layers.MultiHeadAttention(
num_heads=num_heads, key_dim=projection_dim, dropout=0.1
)(x1, x1)
# Skip connection:对应transformer块中的第一个加号
x2 = layers.Add()([attention_output, encoded_patches])
# 对应第二个Norm
x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
# MLP.
x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
# Skip connection 2.
encoded_patches = layers.Add()([x3, x2])
# Create a [batch_size, projection_dim] tensor.
representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
representation = layers.Flatten()(representation)
representation = layers.Dropout(0.5)(representation)
# Add MLP.
features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
# Classify outputs:这个参数对应的是数据集中的类别数
logits = layers.Dense(num_classes)(features)
# 采用keras中的Model进行封装
model = keras.Model(inputs=inputs, outputs=logits)
return model
def run_experiment(model):
optimizer = tfa.optimizers.AdamW(
learning_rate=learning_rate, weight_decay=weight_decay
)
model.compile(
optimizer=optimizer,
loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[
keras.metrics.SparseCategoricalAccuracy(name="accuracy"),# 稀疏情况的分类准确率
keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
],
)
#checkpoint_filepath = r".\tmp\checkpoint"
checkpoint_filepath ="model_bak.hdf5"
# 查看训练模型的状态
checkpoint_callback = keras.callbacks.ModelCheckpoint(
checkpoint_filepath,
monitor="val_accuracy",
save_best_only=True,
save_weights_only=True,
)
# tensorflow提供的可视化工具
# usage:tensorboard --logdir=/logs
tensorboard_callback = keras.callbacks.TensorBoard(log_dir='./logs',#日志文件名
histogram_freq=0,#0表示不计算直方图
batch_size=32,
write_graph=True,# 是否可视化图像
write_grads=False,# 是否在 TensorBoard 中可视化梯度值直方图,前提是histogram_freq必须大于0
write_images=False, #是否在 TensorBoard 中将模型权重以图片可视化
embeddings_freq=0,
embeddings_layer_names=None, #被选中的嵌入层会被保存的频率(在训练轮中)
embeddings_metadata=None, #一个字典,对应层的名字到保存有这个嵌入层元数据文件的名字。
embeddings_data=None, #要嵌入在 embeddings_layer_names 指定的层的数据。
update_freq='epoch')
history = model.fit(
x=x_train,
y=y_train,
batch_size=batch_size,
epochs=num_epochs,
validation_split=0.1, # 直接采用训练集中的一部分数据作为验证集
callbacks=[checkpoint_callback],
)
# 加载训练好的模型用于评估
model.load_weights(checkpoint_filepath)
_, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
print(f"Test accuracy: {round(accuracy * 100, 2)}%")
print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")
return history
vit_classifier = create_vit_classifier()
history = run_experiment(vit_classifier)
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss =history.history['val_loss']
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1.1])
plt.title('Training and Validation Accuracy')
plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([-0.1,4.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()