文章目录
本文涉及知识点
- Hugging Face快速入门
- Pytorch中DataLoader和Dataset的基本用法
本文内容
这是Kaggle上NLP的一个入门题目(链接),任务是对文本进行二分类。内容描述:人们会在Twitter上发布一些内容,这些内容有些是灾难事件,例如“白宫着火了,火焰很大”,这就是一个灾难事件。而有一些虽然也带了相关词汇,却不是灾难事件,例如:”天上那朵云好像燃烧的火焰。“。所以本项目的任务就是区分这两种情况。
数据集可以到Kaggle上下载(链接),或者使用百度网盘下载(链接)
最终可以将你的预测结果上传到Kaggle上查看分数(链接)。
你可以在Github上找到本文的源码(链接)。你也可以直接使用Google Colab来运行代码(Open In Google Colab)
环境配置
本项目使用库版本如下
导入本文要使用的所有依赖包:
全局配置
数据处理
加载数据集
请先下载数据集,并解压到dataset
目录下,其中会有train.csv、test.csv和sample_submission.csv三个文件。
使用pandas来加载训练数据,对于训练数据,我们只需要text和target两行:
加载成功后,来看一下内容:
text | target | |
0 | Our Deeds are the Reason of this #earthquake M... | 1 |
1 | Forest fire near La Ronge Sask. Canada | 1 |
2 | All residents asked to 'shelter in place' are ... | 1 |
3 | 13,000 people receive #wildfires evacuation or... | 1 |
4 | Just got sent this photo from Ruby #Alaska as ... | 1 |
... | ... | ... |
7608 | Two giant cranes holding a bridge collapse int... | 1 |
7609 | @aria_ahrary @TheTawniest The out of control w... | 1 |
7610 | M1.94 [01:04 UTC]?5km S of Volcano Hawaii. htt... | 1 |
7611 | Police investigating after an e-bike collided ... | 1 |
7612 | The Latest: More Homes Razed by Northern Calif... | 1 |
text就是推文,target就是该推文是否是在描述一个灾难事件,1:是,0:否。
Dataset and Dataloader
我们将训练数据按比例随机分为训练集和验证集:
加载好数据集后,我们就可以开始构建Dataset了,我们这里Dataset就是返回推文和其target:
我们来打印看一下;
构造好Dataset后,就可以来构造Dataloader了。在构造Dataloader前,我们需要先定义好分词器:
我们来尝试使用一下分词器:
可以正常运行。其中101表示“开始”([CLS]
),102表示句子结束([SEP]
)。
我们接着构造我们的Dataloader。我们需要定义一下collate_fn,在其中完成对句子进行编码、填充、组装batch等动作:
我们来看一眼train_loader的数据:
构建模型
训练模型
接下来开始正式训练模型,首先定义出损失函数和优化器。因为是二分类问题,用Binary Cross Entropy就行:
这个学习率是我测试出来的,之前用的
3e-4
,发现怎么都不收敛。看来学习率确实很重要。
定义一个验证方法,获取到验证集的精准率和loss。
开始训练:
模型使用
加载最好的模型,然后按照Kaggle的要求组装csv文件。
构造测试集的dataloader。测试集是不包含target的。
将测试数据送入模型,得到结果,最后组装成Kaggle要求数据结构:
拿着结果去Kaggle上试一下吧,看看你能得多少分。我这边跑了10个Epoch,最终得到了0.83573的分数,还行。