PyTorch新闻文本分类教程

介绍

在本篇教程中,我将向你介绍如何使用PyTorch进行新闻文本分类。我将指导你完成以下步骤:

  1. 数据准备:下载和预处理数据集
  2. 构建模型:定义一个用于文本分类的深度学习模型
  3. 训练模型:使用数据训练模型
  4. 测试模型:评估模型在测试数据上的性能
  5. 预测:使用训练好的模型对新的文本进行分类

数据准备

在进行文本分类之前,我们需要准备一个适合的数据集。在本教程中,我们将使用一个名为AG News的公开数据集。该数据集包含了数万条新闻标题和对应的类别标签。

首先,我们需要下载数据集。可以使用以下代码来下载和解压数据集:

import os
import urllib.request
import tarfile

# 创建存储数据集的文件夹
data_folder = 'data'
os.makedirs(data_folder, exist_ok=True)

# 下载数据集压缩文件
url = '
filename = os.path.join(data_folder, 'ag_news_csv.tar.gz')
urllib.request.urlretrieve(url, filename)

# 解压数据集文件
with tarfile.open(filename, 'r:gz') as tar:
    tar.extractall(data_folder)

完成数据集的下载和解压后,我们需要加载数据并进行预处理。在PyTorch中,我们可以使用torchtext库来处理文本数据。下面的代码展示了如何使用torchtext加载数据集并进行预处理:

import torchtext
from torchtext.data.utils import get_tokenizer
from torchtext.datasets import text_classification

# 设置随机种子
torch.manual_seed(1234)

# 定义文本的tokenizer
tokenizer = get_tokenizer('basic_english')

# 加载数据集
train_dataset, test_dataset = text_classification.DATASETS['AG_NEWS'](
    root=data_folder, ngrams=2, vocab=None, tokenizer=tokenizer)

# 打印数据集信息
print(f'训练集大小:{len(train_dataset.examples)}')
print(f'测试集大小:{len(test_dataset.examples)}')

该代码块中,我们使用了text_classification.DATASETS中的AG_NEWS数据集,设置了ngrams为2,即每个词汇由1-2个单词组成。可以根据需要调整ngrams的值。

构建模型

在本教程中,我们将使用卷积神经网络(CNN)作为模型进行文本分类。CNN在计算机视觉任务中取得了很好的效果,现在也被广泛应用于自然语言处理任务。

我们可以使用PyTorch的nn.Module类来定义我们的模型。下面的代码展示了如何定义一个简单的文本分类模型:

import torch.nn as nn
import torch.nn.functional as F

class TextCNN(nn.Module):
    def __init__(self, vocab_size, embedding_dim, num_classes, num_filters, filter_sizes):
        super(TextCNN, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embedding_dim)
        self.convs = nn.ModuleList([
            nn.Conv2d(1, num_filters, (fs, embedding_dim)) for fs in filter_sizes
        ])
        self.fc = nn.Linear(len(filter_sizes) * num_filters, num_classes)

    def forward(self, text):
        embedded = self.embedding(text)
        embedded = embedded.unsqueeze(1)
        conved = [F.relu(conv(embedded)).squeeze(3) for conv in self.convs]
        pooled = [F.max_pool1d(conv, conv.size(2)).squeeze(2) for conv in conved]
        cat = torch.cat(pooled, dim=1)
        output = self.fc(cat)
        return output

上述代码中,我们使用了一个embedding层,该层将文本转化为固定长度的向量表示。接下来,我们使用多个不同大小的卷积核对文本进行卷积操作,并使用ReLU激活函数。然后,我们使用最大池化层对每个卷积操作的输出进行池化。最后,我们通过一个全连接层将所有卷积操作的结果进行分类