在系列文章NLP(三十四)使用keras-bert实现序列标注任务、NLP(三十五)使用keras-bert实现文本多分类任务、NLP(三十六)使用keras-bert实现文本多标签分类任务中,笔者介绍了如何使用keras-bert模块来调用BERT等模型来实现文本分类、文本多标签分类、序列标注任务。
  在系列文章NLP(二十二)利用ALBERT实现文本二分类、NLP(二十五)实现ALBERT+Bi-LSTM+CRF模型、NLP(二十八)多标签文本分类中,笔者将ALBERT模型作为特征向量提取工具,实现了文本分类、文本多标签分类、序列标注任务。
  本文将会介绍如何使用keras-bert调用ALBERT模型实现文本分类、文本多标签分类、序列标注任务,其模型效果比单纯将ALBERT模型作为特征向量提取工具的效果肯定来得好。
  使用keras-bert调用ALBERT模型实现文本多分类任务的Github项目网址:https://github.com/percent4/keras_albert_text_classification
  使用keras-bert调用ALBERT模型实现文本多标签分类任务的Github项目网址:https://github.com/percent4/keras_albert_multi_label_cls
  使用keras-bert调用ALBERT模型实现序列标注任务的Github项目网址:https://github.com/percent4/keras_albert_sequence_labeling

如何使用keras-bert调用ALBERT模型

  keras-bert模块的设计之初是为了支持BERT系列模型,它并不支持ALBERT模型。但在开源世界Github中有个项目名为keras_albert_model,其网址为:https://github.com/TinkerMob/keras_albert_model,利用这个项目,我们可以做到让keras-bert支持ALBERT模型。
  下载该项目中的albert.py脚本,我们使用如下示例代码来调用ALBERT-tiny模型,代码如下:

# -*- coding: utf-8 -*-
from albert import load_brightmart_albert_zh_checkpoint

model = load_brightmart_albert_zh_checkpoint('albert_xlarge_zh_183k', training=False)
model.summary()

输出的albert-tiny模型结构如下:

__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
Input-Token (InputLayer)        (None, 512)          0                                            
__________________________________________________________________________________________________
Input-Segment (InputLayer)      (None, 512)          0                                            
__________________________________________________________________________________________________
Embed-Token (AdaptiveEmbedding) [(None, 512, 312), ( 2744320     Input-Token[0][0]                
__________________________________________________________________________________________________
Embed-Segment (Embedding)       (None, 512, 312)     624         Input-Segment[0][0]              
__________________________________________________________________________________________________
Embed-Token-Segment (Add)       (None, 512, 312)     0           Embed-Token[0][0]                
                                                                 Embed-Segment[0][0]              
__________________________________________________________________________________________________
Embedding-Position (PositionEmb (None, 512, 312)     159744      Embed-Token-Segment[0][0]        
__________________________________________________________________________________________________
Embedding-Norm (LayerNormalizat (None, 512, 312)     624         Embedding-Position[0][0]         
__________________________________________________________________________________________________
Attention (MultiHeadAttention)  (None, 512, 312)     390624      Embedding-Norm[0][0]             
                                                                 Feed-Forward-Normal[0][0]        
                                                                 Feed-Forward-Normal[1][0]        
                                                                 Feed-Forward-Normal[2][0]        
__________________________________________________________________________________________________
Attention-Add-1 (Add)           (None, 512, 312)     0           Embedding-Norm[0][0]             
                                                                 Attention[0][0]                  
__________________________________________________________________________________________________
Attention-Normal (LayerNormaliz (None, 512, 312)     624         Attention-Add-1[0][0]            
                                                                 Attention-Add-2[0][0]            
                                                                 Attention-Add-3[0][0]            
                                                                 Attention-Add-4[0][0]            
__________________________________________________________________________________________________
Feed-Forward (FeedForward)      (None, 512, 312)     780312      Attention-Normal[0][0]           
                                                                 Attention-Normal[1][0]           
                                                                 Attention-Normal[2][0]           
                                                                 Attention-Normal[3][0]           
__________________________________________________________________________________________________
Feed-Forward-Add-1 (Add)        (None, 512, 312)     0           Attention-Normal[0][0]           
                                                                 Feed-Forward[0][0]               
__________________________________________________________________________________________________
Feed-Forward-Normal (LayerNorma (None, 512, 312)     624         Feed-Forward-Add-1[0][0]         
                                                                 Feed-Forward-Add-2[0][0]         
                                                                 Feed-Forward-Add-3[0][0]         
                                                                 Feed-Forward-Add-4[0][0]         
__________________________________________________________________________________________________
Attention-Add-2 (Add)           (None, 512, 312)     0           Feed-Forward-Normal[0][0]        
                                                                 Attention[1][0]                  
__________________________________________________________________________________________________
Feed-Forward-Add-2 (Add)        (None, 512, 312)     0           Attention-Normal[1][0]           
                                                                 Feed-Forward[1][0]               
__________________________________________________________________________________________________
Attention-Add-3 (Add)           (None, 512, 312)     0           Feed-Forward-Normal[1][0]        
                                                                 Attention[2][0]                  
__________________________________________________________________________________________________
Feed-Forward-Add-3 (Add)        (None, 512, 312)     0           Attention-Normal[2][0]           
                                                                 Feed-Forward[2][0]               
__________________________________________________________________________________________________
Attention-Add-4 (Add)           (None, 512, 312)     0           Feed-Forward-Normal[2][0]        
                                                                 Attention[3][0]                  
__________________________________________________________________________________________________
Feed-Forward-Add-4 (Add)        (None, 512, 312)     0           Attention-Normal[3][0]           
                                                                 Feed-Forward[3][0]               
==================================================================================================
Total params: 4,077,496
Trainable params: 0
Non-trainable params: 4,077,496
__________________________________________________________________________________________________

  有了它,我们可以在keras-bert模块中轻松愉快地调用ALBERT模型了。以下为keras-bert调用ALBERT模型,实现文本多分类、文本多标签分类以及序列标注任务,其代码与之前的keras-bert调用BERT模型的代码大体一致,只不过在加载预训练模型的时候需要keras_albert_model项目的帮助。
  以下将不再给出具体的项目代码,而只是给出keras-bert在调用ALBERT模型时,在不同NLP任务的表现,即模型评估效果。

文本多分类

  使用Keras和ALBERT实现文本多分类任务,其中对ALBERT进行微调。数据集为sougou小分类数据集。模型参数为batch_size = 8, maxlen = 300, epoch=3。

  • albert-tiny的模型评估结果
precision   recall    f1-score   support

          体育     0.9700    0.9798    0.9749        99
          健康     0.9278    0.9091    0.9184        99
          军事     0.9899    0.9899    0.9899        99
          教育     0.8585    0.9192    0.8878        99
          汽车     1.0000    0.9394    0.9688        99

    accuracy                         0.9475       495
   macro avg     0.9492    0.9475    0.9479       495
weighted avg     0.9492    0.9475    0.9479       495
  • albert_base_zh_additional_36k_steps的模型评估结果
precision    recall  f1-score   support

          体育     0.9802    1.0000    0.9900        99
          健康     0.9684    0.9293    0.9485        99
          军事     1.0000    0.9899    0.9949        99
          教育     0.8739    0.9798    0.9238        99
          汽车     1.0000    0.9091    0.9524        99

    accuracy                         0.9616       495
   macro avg     0.9645    0.9616    0.9619       495
weighted avg     0.9645    0.9616    0.9619       495
  • lbert_xlarge_zh_183k的模型评估结果
precision    recall  f1-score   support

          体育     0.9898    0.9798    0.9848        99
          健康     0.9412    0.9697    0.9552        99
          军事     0.9706    1.0000    0.9851        99
          教育     0.9300    0.9394    0.9347        99
          汽车     0.9892    0.9293    0.9583        99

    accuracy                         0.9636       495
   macro avg     0.9642    0.9636    0.9636       495
weighted avg     0.9642    0.9636    0.9636       495

文本多标签

  使用采用Keras和ALBERT实现文本多标签分类任务,其中对ALBERT进行微调。以2020语言与智能技术竞赛:事件抽取任务 中的数据作为多分类标签的样例数据,借助多标签分类模型来解决。模型参数为batch_size = 16, maxlen = 256, epoch=10。

  • albert-tiny的模型评估结果
micro avg     0.9488    0.8606    0.9025      1657
   macro avg     0.9446    0.8084    0.8589      1657
weighted avg     0.9460    0.8606    0.8955      1657
 samples avg     0.8932    0.8795    0.8799      1657

accuracy:  0.828437917222964
hamming loss:  0.0031631919482386773
  • albert_base_zh_additional_36k_steps的模型评估结果
micro avg     0.9471    0.9294    0.9382      1657
   macro avg     0.9416    0.9105    0.9208      1657
weighted avg     0.9477    0.9294    0.9362      1657
 samples avg     0.9436    0.9431    0.9379      1657

accuracy:  0.8931909212283045
hamming loss:  0.0020848310567936736

序列标注

  使用本项目采用Keras和ALBERT实现序列标注,其中对ALBERT进行微调。数据集为人民日报命名实体识别数据集、时间识别数据集、CLUENER细粒度实体识别数据集。

  • 人民日报命名实体识别数据集

1.1 albert-tiny

模型参数:MAX_SEQ_LEN=128, BATCH_SIZE=32, EPOCH=10

运行model_evaluate.py,模型评估结果如下:

precision    recall  f1-score   support

      LOC     0.8266    0.8171    0.8218      3658
      ORG     0.7289    0.7863    0.7565      2185
      PER     0.8865    0.8712    0.8788      1864

micro avg     0.8111    0.8215    0.8163      7707
macro avg     0.8134    0.8215    0.8171      7707

1.2 albert-base

模型参数:MAX_SEQ_LEN=128, BATCH_SIZE=32, EPOCH=10

运行model_evaluate.py,模型评估结果如下:

precision    recall  f1-score   support

      LOC     0.9032    0.8671    0.8848      3658
      PER     0.9270    0.9067    0.9167      1864
      ORG     0.8445    0.8549    0.8497      2185

micro avg     0.8917    0.8732    0.8824      7707
macro avg     0.8923    0.8732    0.8826      7707
  • 时间识别数据集

2.1 albert-tiny

模型参数:MAX_SEQ_LEN=256, BATCH_SIZE=8, EPOCH=10

运行model_evaluate.py,模型评估结果如下:

precision    recall  f1-score   support

     TIME     0.7924    0.8481    0.8193       441

micro avg     0.7924    0.8481    0.8193       441
macro avg     0.7924    0.8481    0.8193       441

2.2 albert-base

模型参数:MAX_SEQ_LEN=256, BATCH_SIZE=8, EPOCH=10

运行model_evaluate.py,模型评估结果如下:

precision    recall  f1-score   support

     TIME     0.8136    0.8413    0.8272       441

micro avg     0.8136    0.8413    0.8272       441
macro avg     0.8136    0.8413    0.8272       441
  • CLUENER细粒度实体识别数据集

3.1 albert-tiny

模型参数:MAX_SEQ_LEN=128, BATCH_SIZE=32, EPOCH=10

运行model_evaluate.py,模型评估结果如下:

precision    recall  f1-score   support

     company     0.5745    0.6639    0.6160       366
organization     0.5677    0.6337    0.5989       344
        game     0.6616    0.7561    0.7057       287
    position     0.6478    0.7012    0.6734       425
  government     0.6237    0.7336    0.6742       244
        name     0.6520    0.7894    0.7141       451
       movie     0.6164    0.6533    0.6343       150
       scene     0.5166    0.5477    0.5317       199
        book     0.6140    0.6908    0.6502       152
     address     0.4071    0.4698    0.4362       364

   micro avg     0.5884    0.6687    0.6260      2982
   macro avg     0.5881    0.6687    0.6255      2982

3.2 albert-base

模型参数:MAX_SEQ_LEN=128, BATCH_SIZE=32, EPOCH=10

运行model_evaluate.py,模型评估结果如下:

precision    recall  f1-score   support

        name     0.8419    0.8381    0.8400       451
     company     0.7161    0.7650    0.7398       366
    position     0.7205    0.7459    0.7329       425
     address     0.5473    0.5879    0.5669       364
        game     0.7033    0.8258    0.7596       287
        book     0.7931    0.7566    0.7744       152
       scene     0.6243    0.5930    0.6082       199
organization     0.6711    0.7297    0.6992       344
       movie     0.7051    0.7333    0.7190       150
  government     0.7567    0.8156    0.7850       244

   micro avg     0.7078    0.7441    0.7255      2982
   macro avg     0.7093    0.7441    0.7257      2982

不同版本ALBERT模型与BERT模型的参数对比

  以文本多分类任务为例,我们的数据集为sougou小分类数据集,文本最大长度为300,不同版本ALBERT模型与BERT模型的参数为:

model name

Total params

Trainable params

albert_tiny

4,079,061

4,079,061

albert_base_zh_additional_36k_steps

10,290,693

10,290,693

albert_xlarge_zh_183k

54,391,813

54,391,813

chinese_L-12_H-768_A-12

101,680,901

101,680,901

  通过上述对比,我们不难发现,即使是ALBERT的large模型,其参数量也比BERT的base版本来的少,这是由于ALBERT模型的结构决定的。

总结

  本文介绍了如何使用keras-bert调用ALBERT模型实现文本分类、文本多标签分类、序列标注任务,其Github项目地址已经在文章开头给出。
  在模型评估时,不少任务都未给出albert_xlarge_zh_183k模型的评估结果,这是由于GPU机器的性能限制,而不是笔者不愿做或偷懒,希望读者理解,同时也希望读者能有机会弥补这个遗憾。
  感谢大家的阅读,也感谢所有为开源项目作出贡献的人~