Part 1: train

# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddlehub as hub
import os, io, csv
from paddlehub.datasets.base_nlp_dataset import InputExample, TextClassificationDataset

class ThuNews(TextClassificationDataset):
    def __init__(self, tokenizer, mode='train', max_seq_len=128):
        if mode == 'train':
            data_file = 'train.txt'
        elif mode == 'test':
            data_file = 'test.txt'
        else:
            data_file = 'valid.txt'
        super(ThuNews, self).__init__(
            base_path=DATA_DIR,
            data_file=data_file,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            mode=mode,
            is_file_with_header=True,
            label_list=['体育', '科技', '社会', '娱乐', '股票', '房产', '教育', '时政', '财经', '星座', '游戏', '家居', '彩票', '时尚'])

    # Parse the examples in the text file
    def _read_file(self, input_file, is_file_with_header: bool = False):
        if not os.path.exists(input_file):
            raise RuntimeError("The file {} is not found.".format(input_file))
        else:
            with io.open(input_file, "r", encoding="UTF-8") as f:
                reader = csv.reader(f, delimiter="\t", quotechar=None)
                examples = []
                seq_id = 0
                header = next(reader) if is_file_with_header else None
                for line in reader:
                    example = InputExample(guid=seq_id, text_a=line[0], label=line[1])
                    seq_id += 1
                    examples.append(example)
                return examples



if __name__ == '__main__':
    model = hub.Module(name='ernie_tiny', version='2.0.1', task='seq-cls', num_classes=14)  # in multi-class tasks, num_classes must be set explicitly; it is 14 for this dataset
    # The line above initializes `model` as a text classification model: the pretrained
    # ERNIE encoder followed by a fully connected classification layer.
    # Location of the dataset
    DATA_DIR = "./thu_news"
    train_dataset = ThuNews(model.get_tokenizer(), mode='train', max_seq_len=128)
    dev_dataset = ThuNews(model.get_tokenizer(), mode='dev', max_seq_len=128)
    test_dataset = ThuNews(model.get_tokenizer(), mode='test', max_seq_len=128)
    for e in train_dataset.examples[:3]:
        print(e)
    optimizer = paddle.optimizer.Adam(learning_rate=5e-5, parameters=model.parameters())  # choice of optimizer and its hyperparameters
    trainer = hub.Trainer(model, optimizer, checkpoint_dir='./ckpt', use_gpu=False)  # executor of the fine-tuning task
    trainer.train(train_dataset, epochs=3, batch_size=32, eval_dataset=dev_dataset, save_interval=1)  # configure training, start it, and pass in the validation set
    result = trainer.evaluate(test_dataset, batch_size=32)  # evaluate the trained model on the test set

Part 2: predict

import paddlehub as hub

data = [
    # 房产 (real estate)
    ["昌平京基鹭府10月29日推别墅1200万套起享97折  新浪房产讯(编辑郭彪)京基鹭府(论坛相册户型样板间点评地图搜索)售楼处位于昌平区京承高速北七家出口向西南公里路南。项目预计10月29日开盘,总价1200万元/套起,2012年年底入住。待售户型为联排户型面积为410-522平方米,独栋户型面积为938平方米,双拼户型面积为522平方米。  京基鹭府项目位于昌平定泗路与东北路交界处。项目周边配套齐全,幼儿园:伊顿双语幼儿园、温莎双语幼儿园;中学:北师大亚太实验学校、潞河中学(北京市重点);大学:王府语言学校、北京邮电大学、现代音乐学院;医院:王府中西医结合医院(三级甲等)、潞河医院、解放军263医院、安贞医院昌平分院;购物:龙德广场、中联万家商厦、世纪华联超市、瑰宝购物中心、家乐福超市;酒店:拉斐特城堡、鲍鱼岛;休闲娱乐设施:九华山庄、温都温泉度假村、小汤山疗养院、龙脉温泉度假村、小汤山文化广场、皇港高尔夫、高地高尔夫、北鸿高尔夫球场;银行:工商银行、建设银行、中国银行、北京农村商业银行;邮局:中国邮政储蓄;其它:北七家建材城、百安居建材超市、北七家镇武装部、北京宏翔鸿企业孵化基地等,享受便捷生活。"],
    # 游戏 (games)
    ["尽管官方到今天也没有公布《使命召唤:现代战争2》的游戏详情,但《使命召唤:现代战争2》首部包含游戏画面的影片终于现身。虽然影片仅有短短不到20秒,但影片最后承诺大家将于美国时间5月24日NBA职业篮球东区决赛时将会揭露更多的游戏内容。  这部只有18秒的广告片闪现了9个镜头,能够辨识的场景有直升机飞向海岛军事工事,有飞机场争夺战,有潜艇和水下工兵,有冰上乘具,以及其他的一些镜头。整体来看《现代战争2》很大可能仍旧与俄罗斯有关。  片尾有一则预告:“May24th,EasternConferenceFinals”,这是什么?这是说当前美国NBA联赛东部总决赛的日期。原来这部视频是NBA季后赛奥兰多魔术对波士顿凯尔特人队时,TNT电视台播放的广告。"],
    # 体育 (sports)
    ["罗马锋王竟公然挑战两大旗帜拉涅利的球队到底错在哪  记者张恺报道主场一球小胜副班长巴里无可吹捧,罗马占优也纯属正常,倒是托蒂罚失点球和前两号门将先后受伤(多尼以三号身份出场)更让人揪心。阵容规模扩大,反而表现不如上赛季,缺乏一流强队的色彩,这是所有球迷对罗马的印象。  拉涅利说:“去年我们带着嫉妒之心看国米,今年我们也有了和国米同等的超级阵容,许多教练都想有罗马的球员。阵容广了,寻找队内平衡就难了,某些时段球员的互相排斥和跟从前相比的落差都正常。有好的一面,也有不好的一面,所幸,我们一直在说一支伟大的罗马,必胜的信念和够级别的阵容,我们有了。”拉涅利的总结由近一阶段困扰罗马的队内摩擦、个别球员闹意见要走人而发,本赛季技术层面强化的罗马一直没有上赛季反扑的面貌,内部变化值得球迷关注。"],
    # 教育 (education)
    ["新总督致力提高加拿大公立教育质量  滑铁卢大学校长约翰斯顿先生于10月1日担任加拿大总督职务。约翰斯顿先生还曾任麦吉尔大学长,并曾在多伦多大学、女王大学和西安大略大学担任教学职位。  约翰斯顿先生在就职演说中表示,要将加拿大建设成为一个“聪明与关爱的国度”。为实现这一目标,他提出三个支柱:支持并关爱家庭、儿童;鼓励学习与创造;提倡慈善和志愿者精神。他尤其强调要关爱并尊重教师,并通过公立教育使每个人的才智得到充分发展。"]
]

label_list=['体育', '科技', '社会', '娱乐', '股票', '房产', '教育', '时政', '财经', '星座', '游戏', '家居', '彩票', '时尚']
label_map = {
    idx: label_text for idx, label_text in enumerate(label_list)
}

model = hub.Module(
    name='ernie_tiny',
    version='2.0.1',
    task='seq-cls',
    load_checkpoint='./ckpt/best_model/model.pdparams',
    label_map=label_map)
results = model.predict(data, max_seq_len=128, batch_size=1, use_gpu=False)
for idx, text in enumerate(data):
    print('Data: {} \t Label: {}'.format(text[0], results[idx]))

Part 3: Workflow walkthrough

PaddleHub 2.0: news text classification with the dynamic-graph pretrained model ERNIE
This project demonstrates how to use the PaddleHub semantic pretrained model ERNIE to complete text classification on a custom dataset.

Please be sure to use a GPU environment; the code below assumes a GPU is available.




1. Introduction
Before 2017, NLP text processing in both industry and academia relied mainly on sequence models, namely the Recurrent Neural Network (RNN).



In recent years, with the development of deep learning, the number of model parameters has grown rapidly, and training them requires much larger datasets to avoid overfitting. For most NLP tasks, however, building large-scale labeled datasets is costly and difficult, especially for tasks involving syntax and semantics. By contrast, large-scale unlabeled corpora are relatively easy to build. Recent work shows that pretrained models (PTMs) trained on large unlabeled corpora learn general-purpose language representations, and fine-tuning them on downstream tasks yields excellent performance. Pretrained models also avoid having to train a model from scratch.



2. Setup
First, install and import the required Python packages.

In [1]
!pip install -U paddlehub -i https://pypi.tuna.tsinghua.edu.cn/simple
Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting paddlehub
  Downloading paddlehub-2.1.0-py3-none-any.whl (211 kB)
Collecting paddle2onnx>=0.5.1
  Downloading paddle2onnx-0.5.1-py3-none-any.whl (93 kB)
Installing collected packages: paddle2onnx, paddlehub
  Attempting uninstall: paddlehub
    Found existing installation: paddlehub 2.0.0rc0
    Uninstalling paddlehub-2.0.0rc0:
      Successfully uninstalled paddlehub-2.0.0rc0
Successfully installed paddle2onnx-0.5.1 paddlehub-2.1.0
In [2]
import paddlehub as hub
import paddle
3. Code walkthrough
Fine-tuning with the PaddleHub Fine-tune API takes four steps:

Select a model
Load a custom dataset
Choose an optimization strategy and runtime configuration
Run fine-tuning and evaluate the model
Step 1: Select a model
In [3]
model = hub.Module(name="ernie", task='seq-cls', num_classes=14) # 在多分类任务中,num_classes需要显式地指定类别数,此处根据数据集设置为14
[2021-04-29 11:23:25,974] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-1.0/ernie_v1_chn_base.pdparams
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for classifier.weight. classifier.weight is not found in the provided dict.
  warnings.warn(("Skip loading for {}. ".format(key) + str(err)))
/opt/conda/envs/python35-paddle120-env/lib/python3.7/site-packages/paddle/fluid/dygraph/layers.py:1303: UserWarning: Skip loading for classifier.bias. classifier.bias is not found in the provided dict.
  warnings.warn(("Skip loading for {}. ".format(key) + str(err)))
The parameters of hub.Module are used as follows:

name: the model name; options include ernie, ernie_tiny, bert-base-cased, bert-base-chinese, roberta-wwm-ext, roberta-wwm-ext-large, and so on.
task: the fine-tuning task; here it is seq-cls, i.e. text classification.
num_classes: the number of classes in the current classification task, determined by the dataset in use; defaults to 2.
NOTE: in multi-class text classification, num_classes must be specified by the user according to the chosen dataset; in this tutorial it is 14.

PaddleHub also provides BERT and other models to choose from. The models that currently support text classification, and the code to load each of them, are listed below:

Model	PaddleHub Module
ERNIE, Chinese	hub.Module(name='ernie')
ERNIE tiny, Chinese	hub.Module(name='ernie_tiny')
ERNIE 2.0 Base, English	hub.Module(name='ernie_v2_eng_base')
ERNIE 2.0 Large, English	hub.Module(name='ernie_v2_eng_large')
BERT-Base, English Cased	hub.Module(name='bert-base-cased')
BERT-Base, English Uncased	hub.Module(name='bert-base-uncased')
BERT-Large, English Cased	hub.Module(name='bert-large-cased')
BERT-Large, English Uncased	hub.Module(name='bert-large-uncased')
BERT-Base, Multilingual Cased	hub.Module(name='bert-base-multilingual-cased')
BERT-Base, Multilingual Uncased	hub.Module(name='bert-base-multilingual-uncased')
BERT-Base, Chinese	hub.Module(name='bert-base-chinese')
BERT-wwm, Chinese	hub.Module(name='chinese-bert-wwm')
BERT-wwm-ext, Chinese	hub.Module(name='chinese-bert-wwm-ext')
RoBERTa-wwm-ext, Chinese	hub.Module(name='roberta-wwm-ext')
RoBERTa-wwm-ext-large, Chinese	hub.Module(name='roberta-wwm-ext-large')
RBT3, Chinese	hub.Module(name='rbt3')
RBTL3, Chinese	hub.Module(name='rbtl3')
ELECTRA-Small, English	hub.Module(name='electra-small')
ELECTRA-Base, English	hub.Module(name='electra-base')
ELECTRA-Large, English	hub.Module(name='electra-large')
ELECTRA-Base, Chinese	hub.Module(name='chinese-electra-base')
ELECTRA-Small, Chinese	hub.Module(name='chinese-electra-small')
With the single line of code above, model is initialized as a text classification model: the pretrained ERNIE encoder followed by a fully connected layer.

The figure referenced above is from https://arxiv.org/pdf/1810.04805.pdf
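If you want to try a different backbone from the table above, only the name argument changes. A minimal sketch, assuming the roberta-wwm-ext module is available in your PaddleHub installation:

import paddlehub as hub

# Load RoBERTa-wwm-ext instead of ERNIE for the same 14-class task;
# any other name from the table above can be swapped in the same way.
model = hub.Module(name='roberta-wwm-ext', task='seq-cls', num_classes=14)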

Step 2: Load a custom dataset
The dataset used in this example is THUCNews, a news text dataset provided by Tsinghua University. THUCNews was built by filtering the historical data of Sina News RSS feeds from 2005 to 2011 and contains 740,000 news documents (2.19 GB), all in UTF-8 plain text. On top of the original Sina News taxonomy, the documents were regrouped into 14 candidate categories: finance, lottery, real estate, stocks, home, education, technology, society, fashion, politics, sports, horoscope, games, and entertainment. To show quickly how to complete a text classification task with PaddleHub, this example randomly samples 9,000 texts from the THUCNews training set as its training set, draws 100 texts from each of the 14 categories of the validation set as its validation set, and builds the test set in the same way as the validation set.

First, extract the dataset.

In [4]
# View dataset directory. This directory will be recovered automatically after resetting environment.
%cd /home/aistudio/data/data16287/
!tar -zxvf thu_news.tar.gz
!ls -hl thu_news

!head -n 3 thu_news/train.txt
/home/aistudio/data/data16287
thu_news/
thu_news/test.txt
thu_news/valid.txt
thu_news/train.txt
total 30M
-rw-r--r-- 1 aistudio aistudio 3.7M Nov 19  2019 test.txt
-rw-r--r-- 1 aistudio aistudio  23M Nov 19  2019 train.txt
-rw-r--r-- 1 aistudio aistudio 3.6M Nov 19  2019 valid.txt
text_a	label


For more details, see the PaddleHub documentation on loading custom datasets.
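Each split is a UTF-8, tab-separated file with a header row: the first column (text_a) holds the news text and the second column (label) holds its category, as the head output above shows. Schematically, a row looks like this (placeholder content, not real data):

text_a	label
<news text>	体育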

In [5]
import os, io, csv
from paddlehub.datasets.base_nlp_dataset import InputExample, TextClassificationDataset

# Location of the dataset
DATA_DIR="/home/aistudio/data/data16287/thu_news"
In [6]
class ThuNews(TextClassificationDataset):
    def __init__(self, tokenizer, mode='train', max_seq_len=128):
        if mode == 'train':
            data_file = 'train.txt'
        elif mode == 'test':
            data_file = 'test.txt'
        else:
            data_file = 'valid.txt'
        super(ThuNews, self).__init__(
            base_path=DATA_DIR,
            data_file=data_file,
            tokenizer=tokenizer,
            max_seq_len=max_seq_len,
            mode=mode,
            is_file_with_header=True,
            label_list=['体育', '科技', '社会', '娱乐', '股票', '房产', '教育', '时政', '财经', '星座', '游戏', '家居', '彩票', '时尚'])

    # Parse the examples in the text file
    def _read_file(self, input_file, is_file_with_header: bool = False):
        if not os.path.exists(input_file):
            raise RuntimeError("The file {} is not found.".format(input_file))
        else:
            with io.open(input_file, "r", encoding="UTF-8") as f:
                reader = csv.reader(f, delimiter="\t", quotechar=None)
                examples = []
                seq_id = 0
                header = next(reader) if is_file_with_header else None
                for line in reader:
                    example = InputExample(guid=seq_id, text_a=line[0], label=line[1])
                    seq_id += 1
                    examples.append(example)
                return examples

train_dataset = ThuNews(model.get_tokenizer(), mode='train', max_seq_len=128)
dev_dataset = ThuNews(model.get_tokenizer(), mode='dev', max_seq_len=128)
test_dataset = ThuNews(model.get_tokenizer(), mode='test', max_seq_len=128)
for e in train_dataset.examples[:3]:
    print(e)
[2021-04-29 11:23:34,197] [    INFO] - Found /home/aistudio/.paddlenlp/models/ernie-1.0/vocab.txt
[2021-04-29 11:24:55,104] [    INFO] - Found /home/aistudio/.paddlenlp/models/ernie-1.0/vocab.txt
[2021-04-29 11:25:06,165] [    INFO] - Found /home/aistudio/.paddlenlp/models/ernie-1.0/vocab.txt
Step 3: Choose an optimization strategy and runtime configuration
In [7]
optimizer = paddle.optimizer.Adam(learning_rate=5e-5, parameters=model.parameters())  # choice of optimizer and its hyperparameters
trainer = hub.Trainer(model, optimizer, checkpoint_dir='./ckpt', use_gpu=True)        # executor of the fine-tuning task
[2021-04-29 11:25:18,000] [    INFO] - PaddleHub model checkpoint loaded. current_epoch=3 [acc=0.9150]
Optimization strategy
Paddle 2.0 provides a variety of optimizers, such as SGD, Adam, and Adamax; see the optimizer documentation for details.

This tutorial uses the Adam optimizer; its parameters are:

learning_rate: the global learning rate, default 1e-3;
parameters: the model parameters to optimize.
Runtime configuration
Trainer drives the fine-tuning task; it is the entry point of training and exposes the following configurable parameters (a variation is sketched after this list):

model: the model to optimize;
optimizer: the chosen optimizer;
use_gpu: whether to train on GPU;
use_vdl: whether to visualize the training process with VisualDL;
checkpoint_dir: the directory where model parameters are saved;
compare_metrics: the metric used to select the best model to save.
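As a variation on the cell above, the sketch below (an illustrative example with assumed values, not part of the original run) swaps in SGD from the optimizer list and enables VisualDL logging via use_vdl. It reuses the model object created in Step 1.

import paddle
import paddlehub as hub

# Alternative configuration: SGD instead of Adam, with VisualDL logging enabled.
# The learning rate is kept at 5e-5 from the tutorial; './ckpt_sgd' is just an
# example checkpoint directory.
optimizer = paddle.optimizer.SGD(learning_rate=5e-5, parameters=model.parameters())
trainer = hub.Trainer(model, optimizer, checkpoint_dir='./ckpt_sgd', use_gpu=True, use_vdl=True)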
Step 4: Run fine-tuning and evaluate the model
In [8]
trainer.train(train_dataset, epochs=3, batch_size=32, eval_dataset=dev_dataset, save_interval=1)   # configure training, start it, and pass in the validation set
[2021-04-29 11:25:29,302] [   TRAIN] - Epoch=4/3, Step=10/282 loss=0.0421 acc=0.9844 lr=0.000050 step/sec=3.45 | ETA 00:04:05
[2021-04-29 11:25:31,403] [   TRAIN] - Epoch=4/3, Step=20/282 loss=0.0877 acc=0.9781 lr=0.000050 step/sec=4.76 | ETA 00:03:31
[2021-04-29 11:25:33,500] [   TRAIN] - Epoch=4/3, Step=30/282 loss=0.0495 acc=0.9844 lr=0.000050 step/sec=4.77 | ETA 00:03:20
[2021-04-29 11:25:35,594] [   TRAIN] - Epoch=4/3, Step=40/282 loss=0.0390 acc=0.9875 lr=0.000050 step/sec=4.78 | ETA 00:03:14
[2021-04-29 11:25:37,689] [   TRAIN] - Epoch=4/3, Step=50/282 loss=0.0252 acc=0.9969 lr=0.000050 step/sec=4.77 | ETA 00:03:10
[2021-04-29 11:25:39,784] [   TRAIN] - Epoch=4/3, Step=60/282 loss=0.0494 acc=0.9906 lr=0.000050 step/sec=4.78 | ETA 00:03:08
[2021-04-29 11:25:41,885] [   TRAIN] - Epoch=4/3, Step=70/282 loss=0.0437 acc=0.9906 lr=0.000050 step/sec=4.76 | ETA 00:03:07
[2021-04-29 11:25:43,991] [   TRAIN] - Epoch=4/3, Step=80/282 loss=0.0403 acc=0.9875 lr=0.000050 step/sec=4.75 | ETA 00:03:06
[2021-04-29 11:25:46,094] [   TRAIN] - Epoch=4/3, Step=90/282 loss=0.0357 acc=0.9812 lr=0.000050 step/sec=4.76 | ETA 00:03:05
[2021-04-29 11:25:48,192] [   TRAIN] - Epoch=4/3, Step=100/282 loss=0.0172 acc=0.9969 lr=0.000050 step/sec=4.77 | ETA 00:03:04
[2021-04-29 11:25:50,291] [   TRAIN] - Epoch=4/3, Step=110/282 loss=0.0085 acc=1.0000 lr=0.000050 step/sec=4.76 | ETA 00:03:03
[2021-04-29 11:25:52,398] [   TRAIN] - Epoch=4/3, Step=120/282 loss=0.0630 acc=0.9781 lr=0.000050 step/sec=4.75 | ETA 00:03:03
[2021-04-29 11:25:54,508] [   TRAIN] - Epoch=4/3, Step=130/282 loss=0.0633 acc=0.9781 lr=0.000050 step/sec=4.74 | ETA 00:03:02
[2021-04-29 11:25:56,614] [   TRAIN] - Epoch=4/3, Step=140/282 loss=0.0537 acc=0.9875 lr=0.000050 step/sec=4.75 | ETA 00:03:02
[2021-04-29 11:25:58,717] [   TRAIN] - Epoch=4/3, Step=150/282 loss=0.0568 acc=0.9906 lr=0.000050 step/sec=4.76 | ETA 00:03:02
[2021-04-29 11:26:00,825] [   TRAIN] - Epoch=4/3, Step=160/282 loss=0.0865 acc=0.9812 lr=0.000050 step/sec=4.74 | ETA 00:03:02
[2021-04-29 11:26:02,938] [   TRAIN] - Epoch=4/3, Step=170/282 loss=0.0438 acc=0.9844 lr=0.000050 step/sec=4.73 | ETA 00:03:01
[2021-04-29 11:26:05,060] [   TRAIN] - Epoch=4/3, Step=180/282 loss=0.0369 acc=0.9906 lr=0.000050 step/sec=4.71 | ETA 00:03:01
[2021-04-29 11:26:07,196] [   TRAIN] - Epoch=4/3, Step=190/282 loss=0.0418 acc=0.9906 lr=0.000050 step/sec=4.68 | ETA 00:03:01
[2021-04-29 11:26:09,318] [   TRAIN] - Epoch=4/3, Step=200/282 loss=0.0170 acc=1.0000 lr=0.000050 step/sec=4.71 | ETA 00:03:01
[2021-04-29 11:26:11,437] [   TRAIN] - Epoch=4/3, Step=210/282 loss=0.0922 acc=0.9812 lr=0.000050 step/sec=4.72 | ETA 00:03:01
[2021-04-29 11:26:13,559] [   TRAIN] - Epoch=4/3, Step=220/282 loss=0.0524 acc=0.9875 lr=0.000050 step/sec=4.71 | ETA 00:03:01
[2021-04-29 11:26:15,683] [   TRAIN] - Epoch=4/3, Step=230/282 loss=0.0400 acc=0.9906 lr=0.000050 step/sec=4.71 | ETA 00:03:01
[2021-04-29 11:26:17,803] [   TRAIN] - Epoch=4/3, Step=240/282 loss=0.0632 acc=0.9875 lr=0.000050 step/sec=4.72 | ETA 00:03:01
[2021-04-29 11:26:19,914] [   TRAIN] - Epoch=4/3, Step=250/282 loss=0.0436 acc=0.9812 lr=0.000050 step/sec=4.74 | ETA 00:03:01
[2021-04-29 11:26:22,036] [   TRAIN] - Epoch=4/3, Step=260/282 loss=0.0459 acc=0.9844 lr=0.000050 step/sec=4.71 | ETA 00:03:01
[2021-04-29 11:26:24,160] [   TRAIN] - Epoch=4/3, Step=270/282 loss=0.0819 acc=0.9812 lr=0.000050 step/sec=4.71 | ETA 00:03:00
[2021-04-29 11:26:26,275] [   TRAIN] - Epoch=4/3, Step=280/282 loss=0.0344 acc=0.9906 lr=0.000050 step/sec=4.73 | ETA 00:03:00
[2021-04-29 11:26:29,806] [    EVAL] - [Evaluation result] avg_acc=0.9036
[2021-04-29 11:26:29,808] [    INFO] - Saving model checkpoint to ./ckpt/epoch_4
[2021-04-29 11:26:42,820] [   TRAIN] - Epoch=5/3, Step=10/282 loss=0.0750 acc=0.9844 lr=0.000050 step/sec=0.73 | ETA 00:03:41
[2021-04-29 11:26:44,930] [   TRAIN] - Epoch=5/3, Step=20/282 loss=0.0442 acc=0.9844 lr=0.000050 step/sec=4.74 | ETA 00:03:39
[2021-04-29 11:26:47,043] [   TRAIN] - Epoch=5/3, Step=30/282 loss=0.0545 acc=0.9844 lr=0.000050 step/sec=4.73 | ETA 00:03:38
[2021-04-29 11:26:49,154] [   TRAIN] - Epoch=5/3, Step=40/282 loss=0.0527 acc=0.9875 lr=0.000050 step/sec=4.74 | ETA 00:03:37
[2021-04-29 11:26:51,270] [   TRAIN] - Epoch=5/3, Step=50/282 loss=0.0244 acc=0.9938 lr=0.000050 step/sec=4.73 | ETA 00:03:36
[2021-04-29 11:26:53,388] [   TRAIN] - Epoch=5/3, Step=60/282 loss=0.0735 acc=0.9844 lr=0.000050 step/sec=4.72 | ETA 00:03:35
[2021-04-29 11:26:55,512] [   TRAIN] - Epoch=5/3, Step=70/282 loss=0.0341 acc=0.9875 lr=0.000050 step/sec=4.71 | ETA 00:03:34
[2021-04-29 11:26:57,633] [   TRAIN] - Epoch=5/3, Step=80/282 loss=0.0251 acc=0.9938 lr=0.000050 step/sec=4.72 | ETA 00:03:33
[2021-04-29 11:26:59,752] [   TRAIN] - Epoch=5/3, Step=90/282 loss=0.0375 acc=0.9906 lr=0.000050 step/sec=4.72 | ETA 00:03:32
[2021-04-29 11:27:01,869] [   TRAIN] - Epoch=5/3, Step=100/282 loss=0.0132 acc=0.9969 lr=0.000050 step/sec=4.72 | ETA 00:03:31
[2021-04-29 11:27:03,994] [   TRAIN] - Epoch=5/3, Step=110/282 loss=0.0453 acc=0.9875 lr=0.000050 step/sec=4.71 | ETA 00:03:30
[2021-04-29 11:27:06,114] [   TRAIN] - Epoch=5/3, Step=120/282 loss=0.0600 acc=0.9781 lr=0.000050 step/sec=4.72 | ETA 00:03:29
[2021-04-29 11:27:08,241] [   TRAIN] - Epoch=5/3, Step=130/282 loss=0.0504 acc=0.9906 lr=0.000050 step/sec=4.70 | ETA 00:03:29
[2021-04-29 11:27:10,388] [   TRAIN] - Epoch=5/3, Step=140/282 loss=0.0186 acc=0.9938 lr=0.000050 step/sec=4.66 | ETA 00:03:28
[2021-04-29 11:27:12,533] [   TRAIN] - Epoch=5/3, Step=150/282 loss=0.0382 acc=0.9875 lr=0.000050 step/sec=4.66 | ETA 00:03:27
[2021-04-29 11:27:14,680] [   TRAIN] - Epoch=5/3, Step=160/282 loss=0.0534 acc=0.9844 lr=0.000050 step/sec=4.66 | ETA 00:03:27
[2021-04-29 11:27:16,823] [   TRAIN] - Epoch=5/3, Step=170/282 loss=0.0525 acc=0.9844 lr=0.000050 step/sec=4.67 | ETA 00:03:26
[2021-04-29 11:27:18,965] [   TRAIN] - Epoch=5/3, Step=180/282 loss=0.0342 acc=0.9875 lr=0.000050 step/sec=4.67 | ETA 00:03:26
[2021-04-29 11:27:21,107] [   TRAIN] - Epoch=5/3, Step=190/282 loss=0.0202 acc=0.9938 lr=0.000050 step/sec=4.67 | ETA 00:03:25
[2021-04-29 11:27:23,253] [   TRAIN] - Epoch=5/3, Step=200/282 loss=0.0423 acc=0.9938 lr=0.000050 step/sec=4.66 | ETA 00:03:25
[2021-04-29 11:27:25,395] [   TRAIN] - Epoch=5/3, Step=210/282 loss=0.0175 acc=0.9938 lr=0.000050 step/sec=4.67 | ETA 00:03:24
[2021-04-29 11:27:27,544] [   TRAIN] - Epoch=5/3, Step=220/282 loss=0.0453 acc=0.9938 lr=0.000050 step/sec=4.65 | ETA 00:03:24
[2021-04-29 11:27:29,690] [   TRAIN] - Epoch=5/3, Step=230/282 loss=0.0393 acc=0.9938 lr=0.000050 step/sec=4.66 | ETA 00:03:23
[2021-04-29 11:27:31,836] [   TRAIN] - Epoch=5/3, Step=240/282 loss=0.0332 acc=0.9906 lr=0.000050 step/sec=4.66 | ETA 00:03:23
[2021-04-29 11:27:33,981] [   TRAIN] - Epoch=5/3, Step=250/282 loss=0.0284 acc=0.9906 lr=0.000050 step/sec=4.66 | ETA 00:03:22
[2021-04-29 11:27:36,132] [   TRAIN] - Epoch=5/3, Step=260/282 loss=0.0508 acc=0.9906 lr=0.000050 step/sec=4.65 | ETA 00:03:22
[2021-04-29 11:27:38,282] [   TRAIN] - Epoch=5/3, Step=270/282 loss=0.0436 acc=0.9875 lr=0.000050 step/sec=4.65 | ETA 00:03:22
[2021-04-29 11:27:40,424] [   TRAIN] - Epoch=5/3, Step=280/282 loss=0.0674 acc=0.9844 lr=0.000050 step/sec=4.67 | ETA 00:03:21
[2021-04-29 11:27:44,102] [    EVAL] - [Evaluation result] avg_acc=0.9186
[2021-04-29 11:27:55,820] [    EVAL] - Saving best model to ./ckpt/best_model [best acc=0.9186]
[2021-04-29 11:27:55,823] [    INFO] - Saving model checkpoint to ./ckpt/epoch_5
[2021-04-29 11:28:08,954] [   TRAIN] - Epoch=6/3, Step=10/282 loss=0.0388 acc=0.9875 lr=0.000050 step/sec=0.42 | ETA 00:03:59
[2021-04-29 11:28:11,062] [   TRAIN] - Epoch=6/3, Step=20/282 loss=0.0203 acc=0.9906 lr=0.000050 step/sec=4.74 | ETA 00:03:58
[2021-04-29 11:28:13,189] [   TRAIN] - Epoch=6/3, Step=30/282 loss=0.0434 acc=0.9844 lr=0.000050 step/sec=4.70 | ETA 00:03:57
[2021-04-29 11:28:15,319] [   TRAIN] - Epoch=6/3, Step=40/282 loss=0.0164 acc=0.9969 lr=0.000050 step/sec=4.69 | ETA 00:03:56
[2021-04-29 11:28:17,439] [   TRAIN] - Epoch=6/3, Step=50/282 loss=0.0146 acc=0.9969 lr=0.000050 step/sec=4.72 | ETA 00:03:55
[2021-04-29 11:28:19,556] [   TRAIN] - Epoch=6/3, Step=60/282 loss=0.0356 acc=0.9875 lr=0.000050 step/sec=4.72 | ETA 00:03:54
[2021-04-29 11:28:21,680] [   TRAIN] - Epoch=6/3, Step=70/282 loss=0.0325 acc=0.9875 lr=0.000050 step/sec=4.71 | ETA 00:03:53
[2021-04-29 11:28:23,809] [   TRAIN] - Epoch=6/3, Step=80/282 loss=0.0556 acc=0.9875 lr=0.000050 step/sec=4.70 | ETA 00:03:53
[2021-04-29 11:28:25,943] [   TRAIN] - Epoch=6/3, Step=90/282 loss=0.0646 acc=0.9812 lr=0.000050 step/sec=4.69 | ETA 00:03:52
[2021-04-29 11:28:28,058] [   TRAIN] - Epoch=6/3, Step=100/282 loss=0.0327 acc=0.9875 lr=0.000050 step/sec=4.73 | ETA 00:03:51
[2021-04-29 11:28:30,180] [   TRAIN] - Epoch=6/3, Step=110/282 loss=0.0733 acc=0.9781 lr=0.000050 step/sec=4.71 | ETA 00:03:50
[2021-04-29 11:28:32,320] [   TRAIN] - Epoch=6/3, Step=120/282 loss=0.0295 acc=0.9906 lr=0.000050 step/sec=4.67 | ETA 00:03:49
[2021-04-29 11:28:34,465] [   TRAIN] - Epoch=6/3, Step=130/282 loss=0.0270 acc=0.9906 lr=0.000050 step/sec=4.66 | ETA 00:03:49
[2021-04-29 11:28:36,608] [   TRAIN] - Epoch=6/3, Step=140/282 loss=0.0340 acc=0.9875 lr=0.000050 step/sec=4.67 | ETA 00:03:48
[2021-04-29 11:28:38,758] [   TRAIN] - Epoch=6/3, Step=150/282 loss=0.0288 acc=0.9906 lr=0.000050 step/sec=4.65 | ETA 00:03:47
[2021-04-29 11:28:40,908] [   TRAIN] - Epoch=6/3, Step=160/282 loss=0.0208 acc=0.9969 lr=0.000050 step/sec=4.65 | ETA 00:03:47
[2021-04-29 11:28:43,056] [   TRAIN] - Epoch=6/3, Step=170/282 loss=0.0114 acc=1.0000 lr=0.000050 step/sec=4.66 | ETA 00:03:46
[2021-04-29 11:28:45,202] [   TRAIN] - Epoch=6/3, Step=180/282 loss=0.0277 acc=0.9969 lr=0.000050 step/sec=4.66 | ETA 00:03:46
[2021-04-29 11:28:47,351] [   TRAIN] - Epoch=6/3, Step=190/282 loss=0.0099 acc=0.9969 lr=0.000050 step/sec=4.65 | ETA 00:03:45
[2021-04-29 11:28:49,482] [   TRAIN] - Epoch=6/3, Step=200/282 loss=0.0387 acc=0.9875 lr=0.000050 step/sec=4.69 | ETA 00:03:44
[2021-04-29 11:28:51,610] [   TRAIN] - Epoch=6/3, Step=210/282 loss=0.0090 acc=0.9969 lr=0.000050 step/sec=4.70 | ETA 00:03:44
[2021-04-29 11:28:53,739] [   TRAIN] - Epoch=6/3, Step=220/282 loss=0.0271 acc=0.9938 lr=0.000050 step/sec=4.70 | ETA 00:03:43
[2021-04-29 11:28:55,875] [   TRAIN] - Epoch=6/3, Step=230/282 loss=0.0321 acc=0.9938 lr=0.000050 step/sec=4.68 | ETA 00:03:43
[2021-04-29 11:28:58,021] [   TRAIN] - Epoch=6/3, Step=240/282 loss=0.0539 acc=0.9875 lr=0.000050 step/sec=4.66 | ETA 00:03:42
[2021-04-29 11:29:00,159] [   TRAIN] - Epoch=6/3, Step=250/282 loss=0.0088 acc=0.9969 lr=0.000050 step/sec=4.68 | ETA 00:03:42
[2021-04-29 11:29:02,286] [   TRAIN] - Epoch=6/3, Step=260/282 loss=0.0123 acc=0.9969 lr=0.000050 step/sec=4.70 | ETA 00:03:41
[2021-04-29 11:29:04,423] [   TRAIN] - Epoch=6/3, Step=270/282 loss=0.0142 acc=0.9938 lr=0.000050 step/sec=4.68 | ETA 00:03:41
[2021-04-29 11:29:06,558] [   TRAIN] - Epoch=6/3, Step=280/282 loss=0.0419 acc=0.9906 lr=0.000050 step/sec=4.68 | ETA 00:03:40
[2021-04-29 11:29:10,128] [    EVAL] - [Evaluation result] avg_acc=0.9186
[2021-04-29 11:29:10,130] [    INFO] - Saving model checkpoint to ./ckpt/epoch_6
trainer.train controls the actual training run and exposes the following configurable parameters (a usage sketch follows this list):

train_dataset: the dataset used for training;
epochs: the number of training epochs;
batch_size: the training batch size; when using a GPU, adjust it to what your hardware allows;
num_workers: the number of data-loading workers, default 0;
eval_dataset: the validation set;
log_interval: the logging interval, measured in training steps (batches);
save_interval: the checkpoint-saving interval, measured in epochs.
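For reference, here is the same call with the optional parameters from the list above spelled out; the values are illustrative placeholders, not tuned recommendations.

# Same training call as above, with the optional knobs made explicit.
trainer.train(
    train_dataset,
    epochs=3,
    batch_size=32,
    num_workers=0,            # data-loading workers
    eval_dataset=dev_dataset,
    log_interval=10,          # log every 10 training steps
    save_interval=1)          # save a checkpoint every epoch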
In [9]
result = trainer.evaluate(test_dataset, batch_size=32)    # evaluate the trained model on the test set
[2021-04-29 11:29:25,390] [    EVAL] - [Evaluation result] avg_acc=0.9236
4. Use the model for prediction
Once fine-tuning is finished, we load the best model saved during training to run prediction. The complete prediction code is as follows:

In [10]
# Data to be predicted
data = [
    # 房产 (real estate)
    ["昌平京基鹭府10月29日推别墅1200万套起享97折  新浪房产讯(编辑郭彪)京基鹭府(论坛相册户型样板间点评地图搜索)售楼处位于昌平区京承高速北七家出口向西南公里路南。项目预计10月29日开盘,总价1200万元/套起,2012年年底入住。待售户型为联排户型面积为410-522平方米,独栋户型面积为938平方米,双拼户型面积为522平方米。  京基鹭府项目位于昌平定泗路与东北路交界处。项目周边配套齐全,幼儿园:伊顿双语幼儿园、温莎双语幼儿园;中学:北师大亚太实验学校、潞河中学(北京市重点);大学:王府语言学校、北京邮电大学、现代音乐学院;医院:王府中西医结合医院(三级甲等)、潞河医院、解放军263医院、安贞医院昌平分院;购物:龙德广场、中联万家商厦、世纪华联超市、瑰宝购物中心、家乐福超市;酒店:拉斐特城堡、鲍鱼岛;休闲娱乐设施:九华山庄、温都温泉度假村、小汤山疗养院、龙脉温泉度假村、小汤山文化广场、皇港高尔夫、高地高尔夫、北鸿高尔夫球场;银行:工商银行、建设银行、中国银行、北京农村商业银行;邮局:中国邮政储蓄;其它:北七家建材城、百安居建材超市、北七家镇武装部、北京宏翔鸿企业孵化基地等,享受便捷生活。"],
    # 游戏 (games)
    ["尽管官方到今天也没有公布《使命召唤:现代战争2》的游戏详情,但《使命召唤:现代战争2》首部包含游戏画面的影片终于现身。虽然影片仅有短短不到20秒,但影片最后承诺大家将于美国时间5月24日NBA职业篮球东区决赛时将会揭露更多的游戏内容。  这部只有18秒的广告片闪现了9个镜头,能够辨识的场景有直升机飞向海岛军事工事,有飞机场争夺战,有潜艇和水下工兵,有冰上乘具,以及其他的一些镜头。整体来看《现代战争2》很大可能仍旧与俄罗斯有关。  片尾有一则预告:“May24th,EasternConferenceFinals”,这是什么?这是说当前美国NBA联赛东部总决赛的日期。原来这部视频是NBA季后赛奥兰多魔术对波士顿凯尔特人队时,TNT电视台播放的广告。"],
    # 体育 (sports)
    ["罗马锋王竟公然挑战两大旗帜拉涅利的球队到底错在哪  记者张恺报道主场一球小胜副班长巴里无可吹捧,罗马占优也纯属正常,倒是托蒂罚失点球和前两号门将先后受伤(多尼以三号身份出场)更让人揪心。阵容规模扩大,反而表现不如上赛季,缺乏一流强队的色彩,这是所有球迷对罗马的印象。  拉涅利说:“去年我们带着嫉妒之心看国米,今年我们也有了和国米同等的超级阵容,许多教练都想有罗马的球员。阵容广了,寻找队内平衡就难了,某些时段球员的互相排斥和跟从前相比的落差都正常。有好的一面,也有不好的一面,所幸,我们一直在说一支伟大的罗马,必胜的信念和够级别的阵容,我们有了。”拉涅利的总结由近一阶段困扰罗马的队内摩擦、个别球员闹意见要走人而发,本赛季技术层面强化的罗马一直没有上赛季反扑的面貌,内部变化值得球迷关注。"],
    # 教育 (education)
    ["新总督致力提高加拿大公立教育质量  滑铁卢大学校长约翰斯顿先生于10月1日担任加拿大总督职务。约翰斯顿先生还曾任麦吉尔大学长,并曾在多伦多大学、女王大学和西安大略大学担任教学职位。  约翰斯顿先生在就职演说中表示,要将加拿大建设成为一个“聪明与关爱的国度”。为实现这一目标,他提出三个支柱:支持并关爱家庭、儿童;鼓励学习与创造;提倡慈善和志愿者精神。他尤其强调要关爱并尊重教师,并通过公立教育使每个人的才智得到充分发展。"]
]

label_list=['体育', '科技', '社会', '娱乐', '股票', '房产', '教育', '时政', '财经', '星座', '游戏', '家居', '彩票', '时尚']
label_map = { 
    idx: label_text for idx, label_text in enumerate(label_list)
}

model = hub.Module(
    name='ernie',
    task='seq-cls',
    load_checkpoint='./ckpt/best_model/model.pdparams',
    label_map=label_map)
results = model.predict(data, max_seq_len=128, batch_size=1, use_gpu=True)
for idx, text in enumerate(data):
    print('Data: {} \t Label: {}'.format(text[0], results[idx]))
[2021-04-29 11:29:25,406] [    INFO] - Already cached /home/aistudio/.paddlenlp/models/ernie-1.0/ernie_v1_chn_base.pdparams
[2021-04-29 11:29:30,102] [    INFO] - Loaded parameters from /home/aistudio/data/data16287/ckpt/best_model/model.pdparams
[2021-04-29 11:29:30,155] [    INFO] - Found /home/aistudio/.paddlenlp/models/ernie-1.0/vocab.txt
Data: 昌平京基鹭府10月29日推别墅1200万套起享97折  新浪房产讯(编辑郭彪)京基鹭府(论坛相册户型样板间点评地图搜索)售楼处位于昌平区京承高速北七家出口向西南公里路南。项目预计10月29日开盘,总价1200万元/套起,2012年年底入住。待售户型为联排户型面积为410-522平方米,独栋户型面积为938平方米,双拼户型面积为522平方米。  京基鹭府项目位于昌平定泗路与东北路交界处。项目周边配套齐全,幼儿园:伊顿双语幼儿园、温莎双语幼儿园;中学:北师大亚太实验学校、潞河中学(北京市重点);大学:王府语言学校、北京邮电大学、现代音乐学院;医院:王府中西医结合医院(三级甲等)、潞河医院、解放军263医院、安贞医院昌平分院;购物:龙德广场、中联万家商厦、世纪华联超市、瑰宝购物中心、家乐福超市;酒店:拉斐特城堡、鲍鱼岛;休闲娱乐设施:九华山庄、温都温泉度假村、小汤山疗养院、龙脉温泉度假村、小汤山文化广场、皇港高尔夫、高地高尔夫、北鸿高尔夫球场;银行:工商银行、建设银行、中国银行、北京农村商业银行;邮局:中国邮政储蓄;其它:北七家建材城、百安居建材超市、北七家镇武装部、北京宏翔鸿企业孵化基地等,享受便捷生活。 	 Lable: 房产
Data: 尽管官方到今天也没有公布《使命召唤:现代战争2》的游戏详情,但《使命召唤:现代战争2》首部包含游戏画面的影片终于现身。虽然影片仅有短短不到20秒,但影片最后承诺大家将于美国时间5月24日NBA职业篮球东区决赛时将会揭露更多的游戏内容。  这部只有18秒的广告片闪现了9个镜头,能够辨识的场景有直升机飞向海岛军事工事,有飞机场争夺战,有潜艇和水下工兵,有冰上乘具,以及其他的一些镜头。整体来看《现代战争2》很大可能仍旧与俄罗斯有关。  片尾有一则预告:“May24th,EasternConferenceFinals”,这是什么?这是说当前美国NBA联赛东部总决赛的日期。原来这部视频是NBA季后赛奥兰多魔术对波士顿凯尔特人队时,TNT电视台播放的广告。 	 Lable: 游戏
Data: 罗马锋王竟公然挑战两大旗帜拉涅利的球队到底错在哪  记者张恺报道主场一球小胜副班长巴里无可吹捧,罗马占优也纯属正常,倒是托蒂罚失点球和前两号门将先后受伤(多尼以三号身份出场)更让人揪心。阵容规模扩大,反而表现不如上赛季,缺乏一流强队的色彩,这是所有球迷对罗马的印象。  拉涅利说:“去年我们带着嫉妒之心看国米,今年我们也有了和国米同等的超级阵容,许多教练都想有罗马的球员。阵容广了,寻找队内平衡就难了,某些时段球员的互相排斥和跟从前相比的落差都正常。有好的一面,也有不好的一面,所幸,我们一直在说一支伟大的罗马,必胜的信念和够级别的阵容,我们有了。”拉涅利的总结由近一阶段困扰罗马的队内摩擦、个别球员闹意见要走人而发,本赛季技术层面强化的罗马一直没有上赛季反扑的面貌,内部变化值得球迷关注。 	 Lable: 体育
Data: 新总督致力提高加拿大公立教育质量  滑铁卢大学校长约翰斯顿先生于10月1日担任加拿大总督职务。约