CLIP 浅析


文章目录

  • CLIP 浅析
  • 概述
  • 如何训练CLIP
  • 如何使用Clip进行图像分类
  • 优缺点分析
  • 优点
  • 缺点


概述

CLIP的英文全称是Contrastive Language-Image Pre-training,即一种基于对比文本-图像对的预训练方法或者模型

如何训练CLIP

CLIP包括两个模型:Text EncoderImage Encoder,其中Text Encoder用来提取文本的特征,可以采用NLP中常用的text transformer模型;而Image Encoder用来提取图像的特征,可以采用常用CNN模型或者vision transformer。

其中CLIP的流程图如下

CLIP 浅析_数据集

首先CLIP通过一个文本编码器和图像编码器获得相关特征

CLIP 浅析_cnn_02

其中对于通过文本编码器获得的特征记为CLIP 浅析_图像分类_03 表示第CLIP 浅析_transformer_04个文本特征,其中共含有CLIP 浅析_图像分类_05个特征,CLIP 浅析_图像分类_05为训练数据集中的文本信息中的类别个数,对于通过图像编码器获得的特征记为CLIP 浅析_图像分类_07 表示第CLIP 浅析_transformer_04个图像特征,并将CLIP 浅析_图像分类_07与每一个文本特征CLIP 浅析_图像分类_03进行余弦相似度计算。并使用softmax计算概率得到最相似的图文匹配对。其中伪代码如下

# image_encoder - ResNet or Vision Transformer
# text_encoder - CBOW or Text Transformer
# I[n, h, w, c] - minibatch of aligned images
# T[n, l] - minibatch of aligned texts
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_t, d_e] - learned proj of text to embed
# t - learned temperature parameter

# 分别提取图像特征和文本特征
I_f = image_encoder(I) #[n, d_i]
T_f = text_encoder(T) #[n, d_t]

# 对两个特征进行线性投射,得到相同维度的特征,并进行l2归一化
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)

# 计算缩放的余弦相似度:[n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)

# 对称的对比学习损失:等价于N个类别的cross_entropy_loss
labels = np.arange(n) # 对角线元素的labels
loss_i = cross_entropy_loss(logits, labels, axis=0)
loss_t = cross_entropy_loss(logits, labels, axis=1)
loss = (loss_i + loss_t)/2

为了训练CLIP,OpenAI从互联网收集了共4个亿的文本-图像对,论文称之为WebImageText

如何使用Clip进行图像分类

因为ImageNet中的label全是图像类别的表情,为了更好的适应Transformer,作者使用了A photo of {label}的句子作为输入。

CLIP 浅析_编码器_11

伪代码如下

# 首先生成每个类别的文本描述
labels = ["dog", "cat", "bird", "person", "mushroom", "cup"]
text_descriptions = [f"A photo of a {label}" for label in labels]
text_tokens = clip.tokenize(text_descriptions).cuda()

# 提取文本特征
with torch.no_grad():
    text_features = model.encode_text(text_tokens).float()
    text_features /= text_features.norm(dim=-1, keepdim=True)

然后通过计算余弦相似度,并使用softmax计算概率得到最相似的图文匹配对。

优缺点分析

优点

因为CLIP使用图文对的形式进行训练,所以可以从互联网上获得大量的数据进行训练,从而无需大量的人工标注。因为大数据集的原因使得CLIP与CV中常用的先预训练然后微调不同,CLIP可以直接实现zero-shot的图像分类,即不需要任何训练数据,就能在某个具体下游任务上实现分类。

同时OpenAI证明了CLIP在许多数据集上与ResNet50具有相似甚至更好的准确率。

CLIP 浅析_编码器_12

CLIP 浅析_cnn_13

缺点

因为CLIP采用了大量的数据集以及复杂的视觉结构使得它需要消耗恐怖的计算资源。同时对于训练集外的数据,无法做到很好的预测。