BGE-M3模型论文解读(基于MindNLP实现)
BGE M3模型,全称为M3-Embedding,主要通过自我知识蒸馏(Self-Knowledge Distillation)实现了在多语言、多功能和多粒度文本嵌入的显著进步。此模型支持超过100种工作语言的语义检索,能同时完成密集检索(Dense Retrieval)、多向量检索(Multi-Vector Retrieval)和稀疏检索(Sparse Retrieval)三种常见的检索功能,并能处理从短句到长文档(最多8192个令牌)不同粒度的输入。
创新点
BGE M3-Embedding模型的核心创新体现在以下几个方面:
- **多语言支持:**该模型支持超过100种语言的检索,不仅能够在同一语言内部进行检索,还能实现跨语言的检索功能。这种多语言支持对于多语言应用场景中的信息检索具有极大的实际价值,能够在全球化的语境中处理复杂的语言需求。
- **多功能检索:**M3-Embedding模型具有三种主要检索功能:
- **密集检索(Dense Retrieval):**通过计算查询和文档的嵌入向量相似度来进行检索。
- 多向量检索(Multi-Vector Retrieval):基于多个嵌入向量的细粒度交互计算查询和文档的相关性。
- **稀疏检索(Sparse Retrieval):**将输出嵌入用于估计每个词的重要性,从而实现稀疏检索。
M3-Embedding模型不仅能够单独执行这些检索功能,还能将这些功能组合起来,通过加权求和的方式整合不同检索方式的相关性分数,实现更高精度的检索结果。
- 多粒度处理:该模型能够处理从短句到长达8192个令牌的长文档输入,具备极强的适应性。相较于大多数只能处理短文本的模型,M3-Embedding显著提升了对长文档的处理能力,适用于长篇文档的语义匹配和信息检索。
- **自我知识蒸馏:**该模型的另一项关键创新是引入了自我知识蒸馏技术。自我知识蒸馏的原理是将多种检索功能生成的相关性分数作为教师信号,通过知识蒸馏来增强模型的学习效果。这种方法利用不同检索功能的异质性进行集成,类似于集成学习的思想,从而提升了模型的整体性能。
- **优化批处理策略:**为了支持大规模的多语言数据训练,M3-Embedding模型优化了批处理策略,采用了大批量处理和高训练吞吐量,极大地提高了训练效率。通过将训练数据按序列长度进行分组,减少了填充操作,从而更有效地利用GPU资源。这一策略使得模型在处理长文档时能够显著扩大批量大小,进一步提高嵌入的区分度。
- **高质量数据策划:**M3-Embedding模型的训练数据来源广泛,包含了三类数据:
- 大规模的无监督数据,从多语言语料库中提取具有丰富语义结构的文本对;
- 高质量的监督数据,覆盖了多种语言和任务;
- 为长文档检索任务生成的合成数据,通过生成式预训练模型(如GPT-3.5)自动生成问题和对应的文档对,以增强模型对长文档的理解能力。
数据集上的评价指标得分
在多项权威的多语言和跨语言信息检索基准测试中,BGE M3-Embedding模型的表现超越了多个现有模型,达到了新的性能标准。以下是该模型在MIRACL和MKQA等基准数据集上的表现:
- MIRACL数据集:该数据集包含18种语言的检索任务。M3-Embedding在密集检索(Dense Retrieval)、稀疏检索(Sparse Retrieval)和多向量检索(Multi-Vector Retrieval)三种检索任务中均取得了优异的成绩。例如,在nDCG@10的评估指标下,M3-Embedding的平均得分为71.5,显著优于其他方法。
- MKQA数据集:在25种语言的跨语言检索任务中,M3-Embedding同样表现出色。在Recall@100评估指标下,M3-Embedding的得分为75.5,远超大多数现有模型,尤其在低资源语言(如阿拉伯语、希伯来语等)中表现尤为突出。
- **MLDR和NarrativeQA:**M3-Embedding还在长文档检索任务中表现出色。在MLDR和NarrativeQA等长文档数据集上,M3-Embedding通过多功能检索的结合,实现了对长达8192个令牌文档的高效处理,进一步验证了其在长文档场景下的优越性。
相比其他工作的优势
与其他嵌入模型相比,BGE M3-Embedding具有以下显著优势:
- **多语言和跨语言能力:**M3-Embedding支持超过100种语言,并且能够进行高效的跨语言检索。相比于只支持少数几种语言的传统模型,M3-Embedding在全球化应用中具有更强的通用性和适应性。
- **多功能性:**M3-Embedding不仅能进行密集检索,还能同时支持稀疏检索和多向量检索。这种多功能的嵌入模型为信息检索系统提供了极大的灵活性,能够根据不同场景需求选择最合适的检索策略,并且支持多种检索功能的组合使用,从而在复杂任务中获得更高的检索精度。
- **自我知识蒸馏技术:**自我知识蒸馏技术使得M3-Embedding能够有效整合多种检索功能的优势,提升模型的整体性能。相较于传统的知识蒸馏方法,自我知识蒸馏无需依赖外部教师模型,能够更高效地训练多功能检索模型,降低了训练复杂度并提高了效果。
- **长文档处理能力:**M3-Embedding显著提升了对长文档的处理能力,支持最多8192个令牌的输入。这使得该模型在长文档检索和语义匹配任务中具有显著优势,尤其在法律、医学等需要处理大篇幅文档的领域,M3-Embedding的应用潜力巨大。
- **训练效率:**通过优化批处理策略,M3-Embedding能够在大规模多语言数据集上高效训练,同时保持高吞吐量和大批量处理能力。这不仅提高了模型的训练速度,也增强了模型对不同语言和长文档的处理能力。
针对MindNLP实现的评估
对于MindNLP的bge_m3实现,我们使用了官方相同的MKQA基准来评估bge_m3模型的跨语言检索性能。MKQA基准包含了25种非英语语言的查询,每个查询任务都需要从英文维基百科语料库中检索包含答案的段落。为了进行这一实验,我们利用了BEIR提供的精心处理过的语料库BeIR/nq,并采用了密集检索(Dense Retrieval)方式。主要的评价指标包括Recall@100和Recall@20。整个测试过程均与官方相同。
由于资源限制,此次测试仅覆盖150000条数据,导致评估结果低于全量测试。根据测试结果,MindNLP实现与官方实现的误差在1%以内。
测试结果如下:
语言 | BGE-M3 Recall@100 | BGE-M3 Recall@20 | MindNLP Recall@100 | MindNLP Recall@20 | 差值% Recall@100 | 差值% Recall@20 |
ar | 49.54 | 39.52 | 49.54 | 39.54 | 0.00% | 0.02% |
da | 52.15 | 43.53 | 52.15 | 43.59 | 0.00% | 0.06% |
de | 52.17 | 42.85 | 52.15 | 42.83 | -0.02% | -0.02% |
es | 52.47 | 42.79 | 52.44 | 42.79 | -0.03% | 0.00% |
fi | 51.31 | 41.94 | 51.31 | 41.92 | 0.00% | -0.05% |
fr | 52.23 | 43.21 | 52.23 | 43.21 | 0.00% | 0.00% |
he | 49.77 | 39.72 | 49.77 | 39.76 | 0.00% | 0.04% |
hu | 51.02 | 41.83 | 51.02 | 41.83 | 0.00% | 0.00% |
it | 51.91 | 42.67 | 51.88 | 42.73 | -0.03% | 0.06% |
ja | 51.35 | 41.88 | 51.38 | 41.88 | 0.03% | 0.00% |
km | 47.64 | 37.60 | 47.64 | 37.59 | 0.00% | -0.01% |
ko | 49.36 | 39.10 | 49.36 | 39.11 | 0.00% | 0.01% |
ms | 51.90 | 43.30 | 51.93 | 43.28 | 0.03% | -0.02% |
nl | 52.20 | 43.56 | 52.21 | 43.54 | 0.01% | -0.02% |
no | 51.99 | 43.38 | 52.02 | 43.39 | 0.03% | 0.01% |
pl | 51.52 | 42.48 | 51.49 | 42.48 | -0.03% | 0.00% |
pt | 51.90 | 42.85 | 51.91 | 42.83 | 0.01% | -0.02% |
ru | 52.00 | 42.73 | 52.06 | 42.73 | 0.06% | 0.00% |
sv | 51.96 | 42.89 | 51.97 | 42.89 | 0.01% | 0.00% |
th | 51.35 | 42.67 | 51.35 | 42.65 | 0.00% | -0.02% |
tr | 51.37 | 42.26 | 51.38 | 42.29 | 0.01% | 0.03% |
vi | 51.94 | 42.54 | 51.99 | 42.50 | 0.05% | -0.04% |
如结果所示,MindNLP的实现在精度上与官方实现非常接近,误差均低于1%。
推理代码
推理代码的实现参考FlagEmbedding/C_MTEB/MKQA/dense_retrieval at master · FlagOpen/FlagEmbedding (github.com)
完整推理代码BGE-M3-MindNLP-TEST
为适配MindNLP需要修改step0-generate_embedding.py
文件,代码如下:
import os
import sys
import faiss
import datasets
import numpy as np
from tqdm import tqdm
from pprint import pprint
from dataclasses import dataclass, field
from transformers import HfArgumentParser
from mindspore import ops
from mindnlp.transformers import AutoModel, AutoTokenizer
from mindspore import context
import logging
sys.path.append("..")
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def norm(tensor):
"""对生成的嵌入进行归一化"""
norm_ = ops.L2Normalize(axis=-1, epsilon=1e-12)
return norm_(tensor)
@dataclass
class ModelArgs:
encoder: str = field(
default="liuyanyi/bge-m3-hf", # 替换为MindSpore模型的路径
metadata={'help': 'Name or path of encoder'}
)
fp16: bool = field(
default=True,
metadata={'help': 'Use fp16 in inference?'}
)
pooling_method: str = field(
default='cls',
metadata={'help': "Pooling method. Avaliable methods: 'cls', 'mean'"}
)
normalize_embeddings: bool = field(
default=True,
metadata={'help': "Normalize embeddings or not"}
)
@dataclass
class EvalArgs:
index_save_dir: str = field(
default='./corpus-index',
metadata={
'help': 'Dir to save index. Corpus index will be saved to `index_save_dir/{encoder_name}/index`. Corpus ids will be saved to `index_save_dir/{encoder_name}/docid` .'}
)
max_passage_length: int = field(
default=512,
metadata={'help': 'Max passage length.'}
)
batch_size: int = field(
default=82,
metadata={'help': 'Inference batch size.'}
)
overwrite: bool = field(
default=False,
metadata={'help': 'Whether to overwrite embedding'}
)
def get_model(model_args: ModelArgs):
tokenizer = AutoTokenizer.from_pretrained(model_args.encoder)
model = AutoModel.from_pretrained(model_args.encoder)
# model.jit()
return model, tokenizer
from tqdm import tqdm
import itertools
def parse_corpus(corpus: datasets.Dataset, max_samples=None):
corpus_list = []
if max_samples is not None:
iterator = itertools.islice(corpus, max_samples)
else:
iterator = corpus
# 遍历语料数据并进行处理
for data in tqdm(iterator, desc="Generating corpus"):
_id = str(data['_id'])
content = f"{data['title']}\n{data['text']}".lower()
content = normalize(content)
corpus_list.append({"id": _id, "content": content})
# 将处理后的数据转换为 datasets.Dataset 格式
processed_corpus = datasets.Dataset.from_list(corpus_list)
return processed_corpus
import time
def bgeModel(model, inputs):
outputs = model(**inputs, return_dict=True)
return outputs
def generate_embeddings(model, tokenizer, texts, max_passage_length=512, batch_size=1000):
all_embeddings = []
# 初始化 tqdm 进度条
pbar = tqdm(total=len(texts), desc="Generating Embeddings")
for i in range(0, len(texts), batch_size):
batch_texts = texts[i:i + batch_size]
# 开始计时
start_time = time.time()
# 生成嵌入
inputs = tokenizer(batch_texts, return_tensors="ms", padding=True, truncation=True,
max_length=max_passage_length)
# outputs = model(**inputs, return_dict=True)
outputs = bgeModel(model, inputs)
dense_output = outputs.last_hidden_state[:, 0, :] # 使用 [CLS] token 的输出
dense_output = norm(dense_output) # 归一化嵌入
all_embeddings.append(dense_output.asnumpy()) # 转换为 numpy 以便后续使用
end_time = time.time()
batch_time = end_time - start_time
# 更新进度条
pbar.update(len(batch_texts))
pbar.set_postfix(batch_time=f"{batch_time:.4f} s", batch_size=f"{len(batch_texts)}")
pbar.close() # 完成后关闭进度条
return np.vstack(all_embeddings)
def generate_index(model, tokenizer, corpus: datasets.Dataset, max_passage_length: int = 512, batch_size: int = 512):
"""生成FAISS索引"""
corpus_embeddings = generate_embeddings(model, tokenizer, corpus["content"], max_passage_length, batch_size)
dim = corpus_embeddings.shape[-1]
faiss_index = faiss.index_factory(dim, "Flat", faiss.METRIC_INNER_PRODUCT)
corpus_embeddings = corpus_embeddings.astype(np.float32)
faiss_index.train(corpus_embeddings)
faiss_index.add(corpus_embeddings)
return faiss_index, list(corpus["id"])
def save_result(index: faiss.Index, docid: list, index_save_dir: str):
"""保存索引和文档ID"""
docid_save_path = os.path.join(index_save_dir, 'docid')
index_save_path = os.path.join(index_save_dir, 'index')
with open(docid_save_path, 'w', encoding='utf-8') as f:
for _id in docid:
f.write(str(_id) + '\n')
faiss.write_index(index, index_save_path)
def main():
"""主流程"""
parser = HfArgumentParser([ModelArgs, EvalArgs])
model_args, eval_args = parser.parse_args_into_dataclasses()
model_args: ModelArgs
eval_args: EvalArgs
if model_args.encoder[-1] == '/':
model_args.encoder = model_args.encoder[:-1]
model, tokenizer = get_model(model_args=model_args)
encoder = model_args.encoder
if os.path.basename(encoder).startswith('checkpoint-'):
encoder = os.path.dirname(encoder) + '_' + os.path.basename(encoder)
print("==================================================")
print("Start generating embedding with model:")
print(model_args.encoder)
index_save_dir = os.path.join(eval_args.index_save_dir, os.path.basename(encoder))
print("index_save_dir", index_save_dir)
if not os.path.exists(index_save_dir):
os.makedirs(index_save_dir)
if os.path.exists(os.path.join(index_save_dir, 'index')) and not eval_args.overwrite:
print(f'Embedding already exists. Skip...')
return
# 加载数据集
corpus = \
datasets.load_dataset("/home/ma-user/work/workplace/mindnlp/FlagEmbedding/C_MTEB/MKQA/dense_retrieval/nq",
'corpus',
trust_remote_code=True)['corpus']
corpus = parse_corpus(corpus=corpus)
# 生成索引
index, docid = generate_index(
model=model,
tokenizer=tokenizer,
corpus=corpus,
max_passage_length=eval_args.max_passage_length,
batch_size=eval_args.batch_size
)
# 保存结果
save_result(index, docid, index_save_dir)
print("==================================================")
print("Finish generating embeddings with following model:")
pprint(model_args.encoder)
if __name__ == "__main__":
main()
总结
BGE-M3模型凭借M3嵌入模型凭借其在多语言支持、多功能性和处理不同数据粒度的能力上的优异表现,为文本嵌入领域带来了新的突破。
建议各位开发者利用MindNLP等工具来加载并复现该模型的实验成果。MindNLP提供了一套与PyTorch风格一致的简洁接口,加载和评估预训练模型非常直接和高效。