之前面试的时候被问到了这个问题,遂总结一下看到的论文中的多模态对齐方式。

LLaVA

LLaVA 主要做的是 VQA 任务,即输入一张图片和一个问题,LLaVA 可以给出文本的答案。因此 LLaVA 主要涉及两个模态:图像和文本。

【持续更新】总结所有的多模态大模型的对齐方式_大模型

LLaVA 的对齐方式相对来说比较简单,只有简单的线性层。LLaVA 的模型架构如下图所示,LLM 选择的是 Vicuna,图像编码器选择的是 CLIP 的 ViT-L/14,中间增加了一个线性层 W 将图像特征转换为跟文本 Embedding 相同维度,再一起输入到 LLM 中。

【持续更新】总结所有的多模态大模型的对齐方式_多模态_02


可以通过下面的代码直接打印一下模型结构看一下,更加直观,mm_projector是由两个Linear层组成的。

from llava.model.builder import load_pretrained_model
from llava.mm_utils import get_model_name_from_path

model_path = "liuhaotian/llava-v1.5-7b"

tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path=model_path,
    model_base=None,
    model_name=get_model_name_from_path(model_path)
)
print(model)
LlavaLlamaForCausalLM(
  (model): LlavaLlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
	......
    )
    (norm): LlamaRMSNorm()
    (vision_tower): CLIPVisionTower(
	......
    )
    (mm_projector): Sequential(
      (0): Linear(in_features=1024, out_features=4096, bias=True)
      (1): GELU(approximate='none')
      (2): Linear(in_features=4096, out_features=4096, bias=True)
    )
  )
  (lm_head): Linear(in_features=4096, out_features=32000, bias=False)
)

Flamingo

Flamingo 主要做的是 Caption 任务,即输入一张图片,Flamingo 可以生成图片的标题。不同的是,Flamingo 可以输入多张图片,实现上下文学习的 Few-Shot 效果。因此 Flamingo 也主要涉及两个模态:图像和文本。

Flamingo 的模型架构如下图所示,首先通过冻结的视觉编码器对图像进行编码,然后通过一个可训练的感知重采样器(Perceiver Resampler)重新提取特征,输出一个固定数量的视觉 tokens,这些视觉 tokens 再通过交叉注意力层被用于预训练的语言模型的每一层(LM block)。

【持续更新】总结所有的多模态大模型的对齐方式_语言模型_03


Flamingo 中插入的 Perceiver Resampler 和 GATED XATTN-DENSE 都是重新初始化的,GATED XATTN-DENSE 主要是为了根据视觉输入调整 LM,在冻结的 LM 层之间插入新的交叉注意力层。这些交叉注意力层的 keys 和 values 是从视觉特征中获得的,而 queries 则是从语言输入中获得的。交叉注意力层后面跟的是 FFW,这些层都经过了门控(gated),可以让 LM 在初始化的时候保持完整,从而提高稳定性和性能。

【持续更新】总结所有的多模态大模型的对齐方式_计算机视觉_04

不过由于 DeepMind 的 Flamingo 是不开源的,没有办法直接打印模型结构,所以这里我们选择 Christoph Schuhmann 团队开源 OpenFlamingo,通过下面这个代码可以打印一下它的模型结构。

from open_flamingo import create_model_and_transforms

model, image_processor, tokenizer = create_model_and_transforms(
    clip_vision_encoder_path="ViT-L-14",
    clip_vision_encoder_pretrained="openai",
    lang_encoder_path="anas-awadalla/mpt-1b-redpajama-200b",
    tokenizer_path="anas-awadalla/mpt-1b-redpajama-200b",
    cross_attn_every_n_layers=1,
)
print(model)
Flamingo(
  (vision_encoder): VisionTransformer(
	......
  )
  (perceiver): PerceiverResampler(
    (layers): ModuleList(
      (0-5): 6 x ModuleList(
        (0): PerceiverAttention(
          (norm_media): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (norm_latents): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (to_q): Linear(in_features=1024, out_features=512, bias=False)
          (to_kv): Linear(in_features=1024, out_features=1024, bias=False)
          (to_out): Linear(in_features=512, out_features=1024, bias=False)
        )
        (1): Sequential(
          (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=1024, out_features=4096, bias=False)
          (2): GELU(approximate='none')
          (3): Linear(in_features=4096, out_features=1024, bias=False)
        )
      )
    )
    (norm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
  )
  (lang_encoder): MosaicGPT(
	......
    (gated_cross_attn_layers): ModuleList(
      (0-23): 24 x GatedCrossAttentionBlock(
        (attn): MaskedCrossAttention(
          (norm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (to_q): Linear(in_features=2048, out_features=512, bias=False)
          (to_kv): Linear(in_features=1024, out_features=1024, bias=False)
          (to_out): Linear(in_features=512, out_features=2048, bias=False)
        )
        (ff): Sequential(
          (0): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
          (1): Linear(in_features=2048, out_features=8192, bias=False)
          (2): GELU(approximate='none')
          (3): Linear(in_features=8192, out_features=2048, bias=False)
        )
      )
    )
  )
)

BLIP-2

BLIP-2 的论文中提出了一种新的视觉-语言模型预训练的方法—— Q-Former,主要分为两个阶段:① 基于冻结的图像编码器进行视觉-语言表征学习;② 基于冻结的 LLM 进行视觉-语言生成学习。Q-Former 是一个可训练的模块,通过 BERT Base 来初始化权重,用来连接冻结的图像编码器和冻结的 LLM。对于不同分辨率的图像,Q-Former 都可以通过图像编码器提取固定数量的输出特征。Q-Former 主要包括两个 Transformer 子模块,① 图像 Transformer 用于跟冻结的图像编码器交互,提取视觉特征;② 文本 Transformer 可以既作为文本编码器和文本解码器。

  1. 视觉-语言表征学习:创建了一批可学习的 query embeddings 作为图像 Transformer 的输入,queries 可以通过自注意力层跟自己或者文本进行交互,也可以通过交叉注意力层跟冻结的图像编码器提取的特征进行交互。
    视觉-语言表征学习通过三种任务进行训练:
  • 图文对比学习:对齐图像表征和文本表征。
  • 图文匹配:判断图文对是否匹配的二分类任务。
  • 基于图像的文本生成:基于图像生成标题。
  1. 视觉-语言生成学习:基于训练好的 Q-Former 模块和可学习的 query embeddings 提取图像特征,然后用全连接层将 Q-Former 的输出维度跟 LLM 的输入维度进行对齐,最后再输入到 LLM 中。

    可以通过下面这段代码打印一下 BLIP-2 的模型结构。
import torch

from lavis.models import load_model_and_preprocess

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model, vis_processors, _ = load_model_and_preprocess(
    name="blip2_opt", model_type="pretrain_opt2.7b", is_eval=True, device=device
)
print(model)
Blip2OPT(
  (visual_encoder): VisionTransformer(
    ......
  )
  (ln_vision): LayerNorm((1408,), eps=1e-05, elementwise_affine=True)
  (Qformer): BertLMHeadModel(
    (bert): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): None
        (position_embeddings): None
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-11): 12 x: BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
            (crossattention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=1408, out_features=768, bias=True)
                (value): Linear(in_features=1408, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=768, out_features=768, bias=True)
                (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
            )
            (intermediate): None
            (output): None
            (intermediate_query): BertIntermediate(
              (dense): Linear(in_features=768, out_features=3072, bias=True)
              (intermediate_act_fn): GELUActivation()
            )
            (output_query): BertOutput(
              (dense): Linear(in_features=3072, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
          )
        )
      )
    )
    (cls): None
  )
  (opt_model): OPTForCausalLM(
    ......
  )
  (opt_proj): Linear(in_features=768, out_features=2560, bias=True)
)