pytorch-textclassification
pytorch-textclassification是一个以pytorch和transformers为基础,专注于文本分类的轻量级自然语言处理工具包。支持中文长文本、短文本的多类分类和多标签分类。
目录
- 数据
- 使用方式
- paper
- 参考
pytorch-textclassification: https://github.com/yongzhuo/Pytorch-NLU
数据
数据来源
所有数据集均来源于网络,只做整理供大家提取方便,如果有侵权等问题,请及时联系删除。
- baidu_event_extract_2020, 项目以 2020语言与智能技术竞赛:事件抽取任务中的数据作为多分类标签的样例数据,借助多标签分类模型来解决, 共13456个样本, 65个类别;
- AAPD-dataset, 数据集出现在论文-SGM: Sequence Generation Model for Multi-label Classification, 英文多标签分类语料, 共55840样本, 54个类别;
- toutiao-news, 今日头条新闻标题, 多标签分类语料, 约300w-语料, 1000+类别;
- unknow-data, 来源未知, 多标签分类语料, 约22339语料, 7个类别;
- SMP2018中文人机对话技术评测(ECDT), SMP2018 中文人机对话技术评测(SMP2018-ECDT)比赛语料, 短文本意图识别语料, 多类分类, 共3069样本, 31个类别;
- 文本分类语料库(复旦)语料, 复旦大学计算机信息与技术系国际数据库中心自然语言处理小组提供的新闻语料, 多类分类语料, 共9804篇文档,分为20个类别。
- MiningZhiDaoQACorpus, 中国科学院软件研究所刘焕勇整理的问答语料, 百度知道问答语料, 可以把领域当作类别, 多类分类语料, 100w+样本, 共17个类别;
- THUCNEWS, 清华大学自然语言处理实验室整理的语料, 新浪新闻RSS订阅频道2005-2011年间的历史数据筛选, 多类分类语料, 74w新闻文档, 14个类别;
- IFLYTEK, 科大讯飞开源的长文本分类语料, APP应用描述的标注数据,包含和日常生活相关的各类应用主题, 链接为CLUE, 共17333样例, 119个类别;
- TNEWS, 今日头条提供的中文新闻标题分类语料, 数据集来自今日头条的新闻版块, 链接为CLUE, 共73360样例, 15个类别;
数据格式
1. 文本分类 (txt格式, 每行为一个json):
1.1 多类分类格式:
{"text": "人站在地球上为什么没有头朝下的感觉", "label": "教育"}
{"text": "我的小baby", "label": "娱乐"}
{"text": "请问这起交通事故是谁的责任居多小车和摩托车发生事故在无红绿灯", "label": "娱乐"}
1.2 多标签分类格式:
{"label": "3|myz|5", "text": "课堂搞东西,没认真听"}
{"label": "3|myz|2", "text": "测验90-94.A-"}
{"label": "3|myz|2", "text": "长江作业未交"}
使用方式
更多样例sample详情见test/tc目录
文本分类(TC), Text-Classification
# !/usr/bin/python
# -*- coding: utf-8 -*-
# !/usr/bin/python
# -*- coding: utf-8 -*-
# @time : 2021/2/23 21:34
# @author : Mo
# @function: 多标签分类, 根据label是否有|myz|分隔符判断是多类分类, 还是多标签分类
# 适配linux
import platform
import json
import sys
import os
path_root = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
sys.path.append(os.path.join(path_root, "pytorch_textclassification"))
print(path_root)
# 分类下的引入, pytorch_textclassification
from tcTools import get_current_time
from tcRun import TextClassification
from tcConfig import model_config
evaluate_steps = 320 # 评估步数
save_steps = 320 # 存储步数
# pytorch预训练模型目录, 必填
pretrained_model_name_or_path = "bert-base-chinese"
# 训练-验证语料地址, 可以只输入训练地址
path_corpus = os.path.join(path_root, "corpus", "text_classification", "school")
path_train = os.path.join(path_corpus, "train.json")
path_dev = os.path.join(path_corpus, "dev.json")
if __name__ == "__main__":
model_config["evaluate_steps"] = evaluate_steps # 评估步数
model_config["save_steps"] = save_steps # 存储步数
model_config["path_train"] = path_train # 训练模语料, 必须
model_config["path_dev"] = path_dev # 验证语料, 可为None
model_config["path_tet"] = None # 测试语料, 可为None
# 损失函数类型,
# multi-class: 可选 None(BCE), BCE, BCE_LOGITS, MSE, FOCAL_LOSS, DICE_LOSS, LABEL_SMOOTH
# multi-label: SOFT_MARGIN_LOSS, PRIOR_MARGIN_LOSS, FOCAL_LOSS, CIRCLE_LOSS, DICE_LOSS等
model_config["path_tet"] = "FOCAL_LOSS"
os.environ["CUDA_VISIBLE_DEVICES"] = str(model_config["CUDA_VISIBLE_DEVICES"])
model_config["pretrained_model_name_or_path"] = pretrained_model_name_or_path
model_config["model_save_path"] = "../output/text_classification/model_{}".format(model_type[idx])
model_config["model_type"] = "BERT"
# main
lc = TextClassification(model_config)
lc.process()
lc.train()
paper
FastText: Bag of Tricks for Efficient Text Classification
TextCNN: Convolutional Neural Networks for Sentence Classification
charCNN-kim: Character-Aware Neural Language Models
charCNN-zhang: Character-level Convolutional Networks for Text Classification
TextRNN: Recurrent Neural Network for Text Classification with Multi-Task Learning
RCNN: Recurrent Convolutional Neural Networks for Text Classification
DCNN: A Convolutional Neural Network for Modelling Sentences
DPCNN: Deep Pyramid Convolutional Neural Networks for Text Categorization
VDCNN: Very Deep Convolutional Networks
CRNN: A C-LSTM Neural Network for Text Classification
DeepMoji: Using millions of emojio ccurrences to learn any-domain represent ations for detecting sentiment, emotion and sarcasm
SelfAttention: Attention Is All You Need
HAN: Hierarchical Attention Networks for Document Classification
CapsuleNet: Dynamic Routing Between Capsules
Transformer(encode or decode): Attention Is All You Need
Bert: BERT: Pre-trainingofDeepBidirectionalTransformersfor LanguageUnderstanding
Xlnet: XLNet: Generalized Autoregressive Pretraining for Language Understanding
Albert: ALBERT: A LITE BERT FOR SELF-SUPERVISED LEARNING OF LANGUAGE REPRESENTATIONS
RoBERTa: RoBERTa: A Robustly Optimized BERT Pretraining Approach
ELECTRA: ELECTRA: Pre-training Text Encoders as Discriminators Rather Than Generators
TextGCN: Graph Convolutional Networks for Text Classification
参考
This library is inspired by and references following frameworks and papers.
- keras与tensorflow版本对应: https://docs.floydhub.com/guides/environments/
- bert4keras: https://github.com/bojone/bert4keras
- Kashgari: https://github.com/BrikerMan/Kashgari
- fastNLP: https://github.com/fastnlp/fastNLP
- HanLP: https://github.com/hankcs/HanLP
- FGM: 【炼丹技巧】功守道:NLP中的对抗训练 + PyTorch实现
- pytorch-loss: pytorch-loss
- PriorLoss: 通过互信息思想来缓解类别不平衡问题
- CircleLoss: 将“softmax+交叉熵”推广到多标签分类问题
- FocalLoss: Focal Loss for Dense Object Detection
- scikit-learn: https://github.com/scikit-learn/scikit-learn
- tqdm: https://github.com/tqdm/tqdm
Reference
For citing this work, you can refer to the present GitHub project. For example, with BibTeX:
@software{Pytorch-NLU,
url = {https://github.com/yongzhuo/Pytorch-NLU},
author = {Yongzhuo Mo},
title = {Pytorch-NLU},
year = {2021}
*希望对你有所帮助!
实验
corpus==unknow-data, pretrain-model==ernie-tiny, batch=32, lr=1e-5, epoch=21
总结
micro-微平均
precision recall f1-score support
micro_avg 0.7920 0.7189 0.7537 466 MARGIN_LOSS
micro_avg 0.6706 0.8519 0.7505 466 PRIOR-MARGIN_LOSS
micro_avg 0.8258 0.6309 0.7153 466 FOCAL_LOSS【0.5, 2】
micro_avg 0.7890 0.7382 0.7627 466 CIRCLE_LOSS
micro_avg 0.7612 0.7661 0.7636 466 DICE_LOSS【直接学习F1?】
micro_avg 0.8062 0.7232 0.7624 466 BCE
micro_avg 0.7825 0.7103 0.7447 466 BCE-Logits
micro_avg 0.7899 0.7017 0.7432 466 BCE-smooth
micro_avg 0.7235 0.8197 0.7686 466 FOCAL_LOSS【0.5, 2】 + PRIOR-MARGIN_LOSS / 2
macro-宏平均
precision recall f1-score support
macro_avg 0.6198 0.5338 0.5641 466 MARGIN_LOSS
macro_avg 0.5103 0.7200 0.5793 466 PRIOR-MARGIN_LOSS
macro_avg 0.7655 0.4973 0.5721 466 FOCAL_LOSS【0.5, 2】
macro_avg 0.6275 0.5235 0.5627 466 CIRCLE_LOSS
macro_avg 0.4287 0.3918 0.4025 466 DICE_LOSS【直接学习F1?】
macro_avg 0.6978 0.5158 0.5828 466 BCE
macro_avg 0.6046 0.5123 0.5433 466 BCE-Logits
macro_avg 0.6963 0.5012 0.5721 466 BCE-smooth
macro_avg 0.6033 0.6809 0.6369 466 FOCAL_LOSS【0.5, 2】 + PRIOR-MARGIN_LOSS / 2
1. batch=32, loss=MARGIN_LOSS, lr=1e-5, epoch=21, 【精确率高些】
precision recall f1-score support
3 0.8102 0.7919 0.8009 221
2 0.8030 0.8030 0.8030 132
1 0.7333 0.4925 0.5893 67
6 0.7143 0.5000 0.5882 10
5 0.7778 0.4828 0.5957 29
0 0.0000 0.0000 0.0000 4
4 0.5000 0.6667 0.5714 3
micro_avg 0.7920 0.7189 0.7537 466
macro_avg 0.6198 0.5338 0.5641 466
weighted_avg 0.7841 0.7189 0.7454 466
2. batch=32, loss=PRIOR-MARGIN_LOSS, lr=1e-5, epoch=21, 【召回率高些】
precision recall f1-score support
3 0.7279 0.8959 0.8032 221
2 0.7039 0.9545 0.8103 132
1 0.5897 0.6866 0.6345 67
6 0.3333 0.5000 0.4000 10
5 0.6296 0.5862 0.6071 29
0 0.1875 0.7500 0.3000 4
4 0.4000 0.6667 0.5000 3
micro_avg 0.6706 0.8519 0.7505 466
macro_avg 0.5103 0.7200 0.5793 466
weighted_avg 0.6799 0.8519 0.7538 466
3. batch=32, loss=FOCAL_LOSS【(0.5, 2)】, lr=1e-5, epoch=21, 【精确率超级高, 0.25效果会变差】
precision recall f1-score support
3 0.8482 0.7330 0.7864 221
2 0.8349 0.6894 0.7552 132
1 0.7586 0.3284 0.4583 67
6 0.6667 0.4000 0.5000 10
5 0.7500 0.4138 0.5333 29
0 1.0000 0.2500 0.4000 4
4 0.5000 0.6667 0.5714 3
micro_avg 0.8258 0.6309 0.7153 466
macro_avg 0.7655 0.4973 0.5721 466
weighted_avg 0.8206 0.6309 0.7038 466
4. batch=32, loss=CIRCLE_LOSS【, lr=1e-5, epoch=21, 【效果很好, 精确率召回率相对比较均衡】
precision recall f1-score support
3 0.8125 0.8235 0.8180 221
2 0.7914 0.8333 0.8118 132
1 0.7333 0.4925 0.5893 67
6 0.6667 0.4000 0.5000 10
5 0.7222 0.4483 0.5532 29
0 0.0000 0.0000 0.0000 4
4 0.6667 0.6667 0.6667 3
micro_avg 0.7890 0.7382 0.7627 466
macro_avg 0.6275 0.5235 0.5627 466
weighted_avg 0.7785 0.7382 0.7521 466
5. batch=32, loss=DICE_LOSS, lr=1e-5, epoch=21, 【F1指标比较高, 少样本数据学不到, 不稳定】
precision recall f1-score support
3 0.7714 0.8552 0.8112 221
2 0.7727 0.9015 0.8322 132
1 0.7347 0.5373 0.6207 67
6 0.0000 0.0000 0.0000 10
5 0.7222 0.4483 0.5532 29
0 0.0000 0.0000 0.0000 4
4 0.0000 0.0000 0.0000 3
micro_avg 0.7612 0.7661 0.7636 466
macro_avg 0.4287 0.3918 0.4025 466
weighted_avg 0.7353 0.7661 0.7441 466
6. batch=32, loss=BCE, lr=1e-5, epoch=21, 【普通的居然意外的好呢】
precision recall f1-score support
3 0.8136 0.8100 0.8118 221
2 0.8029 0.8333 0.8178 132
1 0.8235 0.4179 0.5545 67
6 0.6667 0.4000 0.5000 10
5 0.7778 0.4828 0.5957 29
0 0.0000 0.0000 0.0000 4
4 1.0000 0.6667 0.8000 3
micro_avg 0.8062 0.7232 0.7624 466
macro_avg 0.6978 0.5158 0.5828 466
weighted_avg 0.8009 0.7232 0.7493 466
7. batch=32, loss=BCE_LOGITS, lr=1e-5, epoch=21, 【torch.nn.BCEWithLogitsLoss】
precision recall f1-score support
3 0.7973 0.8009 0.7991 221
2 0.8000 0.7879 0.7939 132
1 0.7317 0.4478 0.5556 67
6 0.6667 0.4000 0.5000 10
5 0.7368 0.4828 0.5833 29
0 0.0000 0.0000 0.0000 4
4 0.5000 0.6667 0.5714 3
micro_avg 0.7825 0.7103 0.7447 466
macro_avg 0.6046 0.5123 0.5433 466
weighted_avg 0.7733 0.7103 0.7344 466
8. batch=32, loss=LABEL_SMOOTH, lr=1e-5, epoch=21, 【BCE-Label-smooth】
precision recall f1-score support
3 0.7945 0.7873 0.7909 221
2 0.8120 0.8182 0.8151 132
1 0.7027 0.3881 0.5000 67
6 0.8000 0.4000 0.5333 10
5 0.7647 0.4483 0.5652 29
0 0.0000 0.0000 0.0000 4
4 1.0000 0.6667 0.8000 3
micro_avg 0.7899 0.7017 0.7432 466
macro_avg 0.6963 0.5012 0.5721 466
weighted_avg 0.7790 0.7017 0.7296 466
9. batch=32, loss=FOCAL_LOSS + PRIOR-MARGIN_LOSS, lr=1e-5, epoch=21, 【这两个Loss混合,宏平均(macro-avg)效果居然意外的好呢!】
【1/2】
precision recall f1-score support
3 0.7640 0.8643 0.8110 221
2 0.7205 0.8788 0.7918 132
1 0.6620 0.7015 0.6812 67
6 0.4167 0.5000 0.4545 10
5 0.7600 0.6552 0.7037 29
0 0.4000 0.5000 0.4444 4
4 0.5000 0.6667 0.5714 3
micro_avg 0.7235 0.8197 0.7686 466
macro_avg 0.6033 0.6809 0.6369 466
weighted_avg 0.7245 0.8197 0.7679 466
【调和平均数】
precision recall f1-score support
3 0.8474 0.7285 0.7835 221
2 0.8304 0.7045 0.7623 132
1 0.8182 0.4030 0.5400 67
6 0.8000 0.4000 0.5333 10
5 0.7143 0.3448 0.4651 29
0 1.0000 0.2500 0.4000 4
4 0.6667 0.6667 0.6667 3
micro_avg 0.8324 0.6395 0.7233 466
macro_avg 0.8110 0.4996 0.5930 466
weighted_avg 0.8292 0.6395 0.7132 466
【1/3 + 2/3-focal】
precision recall f1-score support
3 0.7890 0.8462 0.8166 221
2 0.7516 0.8939 0.8166 132
1 0.6935 0.6418 0.6667 67
6 0.3636 0.4000 0.3810 10
5 0.6538 0.5862 0.6182 29
0 0.4000 0.5000 0.4444 4
4 0.5000 0.6667 0.5714 3
micro_avg 0.7430 0.8004 0.7707 466
macro_avg 0.5931 0.6478 0.6164 466
weighted_avg 0.7420 0.8004 0.7686 466
【1/4-prior + 3/4-focal】
precision recall f1-score support
3 0.7956 0.8100 0.8027 221
2 0.7712 0.8939 0.8281 132
1 0.6981 0.5522 0.6167 67
6 0.6667 0.4000 0.5000 10
5 0.7143 0.5172 0.6000 29
0 0.3333 0.2500 0.2857 4
4 0.5000 0.6667 0.5714 3
micro_avg 0.7656 0.7639 0.7648 466
macro_avg 0.6399 0.5843 0.6007 466
weighted_avg 0.7610 0.7639 0.7581 466
【4/9-prior + 5/9-focal】
precision recall f1-score support
3 0.7819 0.8597 0.8190 221
2 0.7578 0.9242 0.8328 132
1 0.6567 0.6567 0.6567 67
6 0.5000 0.5000 0.5000 10
5 0.6250 0.5172 0.5660 29
0 0.2857 0.5000 0.3636 4
4 0.5000 0.6667 0.5714 3
micro_avg 0.7364 0.8155 0.7739 466
macro_avg 0.5867 0.6607 0.6156 466
weighted_avg 0.7352 0.8155 0.7715 466
希望对你有所帮助!