说明1:该专栏没有规划有多少个系列文章,而是根据每次文章内容难易程度、文章较佳阅读时长决定最终文章篇幅。
说明2:该专栏是对 huggingface 中多个模型的源码解析,源项目地址:
huggingface/transformersgithub.com
Transformers 文档:
Transformers - transformers 2.1.0 documentationhuggingface.co
一、源码阅读的三个原则和一个能力
- 三个原则
- IPO原则:完成一件事情必须要经历输入(Input)—处理(Process)—输出(Output)三个步骤,在IPO原则基础上,再对将要完成的目标添加需要的辅助文件。
- AD原则:根据上述 IPO 原则,找到输入,跳过处理,返回输出,画出概要(Abstract);完成概要设计后,跳到详细(Detail)代码中。若在拆解代码过程时,碰到其他 IPO 过程,依样画葫芦即可。
- BO原则:在 IPO 原则和 AD 原则的基础上,读懂代码,思考代码逻辑,获得不同代码模块的功能。以 Business-Oriented (业务导向),在业务目标下,将各个独立的、不同的功能组装拼接用以支撑业务。
- 一个能力
- 想象能力--不是天马行空般地胡乱猜想,是在原作者设计的基础上,深入思考:
- 作者为何要如此设计?
- 实现这个功能的基本设计流程思路是怎样的?
- 按照功能需要,我们自己设计一个流程,往往我们仅会考虑正常流程,而忽略掉异常流程。完成设计后,对比源码,看看作者在可能出现的问题上又做了哪些针对性的预防措施?
二、各个类之间的关系以及解析策略
转换过程中主要涉及以上十六个类,其中 nn.Module 是 PyTorch 提供的,暂且不谈。剩下的十五个类,我们以 BertModel
- 中心:BertModel
- 向上:
- 两个通用工具类:PretrainedConfig、PreTrainedModel
- BertModel 的直接父类:BertPreTrainedModel,以及它需要关联的 BertConfig 和 BertLayerNorm
- 向下:
- BertModel 关联的三个类:BertEmbeddings、BertEncoder、BertPooler
- BertEncoder 关联的一个类:BertLayer
- BertLayer 关联的三个类:BertAttention、BertIntermediate、BertOutput
- BertAttention 关联的两个类:BertSelfAttention、BertSelfOutput
三、关键函数之间调用关系的时序图
仅给出关键函数之间的调用关系,从而使得逻辑更加清晰,中间存在的其他函数之间的调用关系,可以从源码中获取。
四、convert_bert_pytorch_checkpoint_to_original_tf.py 文件解析
我们从 convert_bert_pytorch_checkpoint_to_original_tf.py 文件中的 main() 函数开始,结合 IPO 和 AD 原则来阅读源码。
main() 函数有参数解析、模型加载、模型转换三个功能(AD原则中 A(Abstract)的体现)。
1. 参数解析
- 参数:模型名称、pytorch版模型的缓存地址、pytorch 版模型的存放地址、tf 版模型的输出地址
parser
【思考1】:在没有源码的情况下,实现 pytorch->tf 模型的转换,那么我们首先会有哪些什么想法呢?
1. 根据 IPO 原则,会先确定输入、输出,暂时忽略处理部分。
输入:pytorch 版模型
输出:tensorflow 版模型
2. 根据 AD 原则,对输入、输出环节进行细化。
输入:
离线下载模型,并存放到具体路径下(pytorch_model_path),需要调用模型的时候,直接加载。
在线下载模型。如果是第一次下载,并在给定的目录中没有找到需要调用的模型,那么,在线下载,下载完成之后并存放到指定路径下(cache_dir)。
输出:
一个包含 tf 模型的具体路径(tf_cache_dir),同时输出应该包含4个文件(tf版模型的特殊之处)。
【源码1】:在完成上面的思考之后,我们瞅瞅看实际源码中的 IPO 过程是怎样的?
目标:将 pytorch 版的模型转成 tf 版的模型
输入:pytorch 版的模型名称【model_name】、pytorch 版模型的存放路径【pytorch_model_path】、pytorch 版模型的缓存路径【cache_dir】
处理:将 pytorch 版的模型 转换成 tf 版的模型
输出:tf 版的模型输出路径【tf_cache_dir】
【校正1】:遗漏了 model_name
【小贴士】
- argparse 核心内容(英文):
https://docs.python.org/3.6/howto/argparse.htmldocs.python.org
- argparse 核心内容(中文):
李小伟:参数解析 argparse 核心内容zhuanlan.zhihu.com
2. 模型加载
- 参数:预训练模型名称或路径、pytorch版模型、pytorch版模型的输入地址
model = BertModel.from_pretrained(
pretrained_model_name_or_path=args.model_name,
state_dict=torch.load(args.pytorch_model_path),
cache_dir=args.cache_dir)
3. 将 pytorch 版的模型 转换成 tf 版的模型
- 参数:模型、checkpoint的输出地址、预训练模型名称
convert_pytorch_checkpoint_to_tf