[PyTorch] 7 Text Classification with TorchText in Practice: Four-Class AG_NEWS News Classification
Published: 2019-05-23


Text Classification with TorchText

This is based on a tutorial from the official PyTorch text series, which introduces how to use the text classification datasets in torchtext. This post is a detailed, annotated walkthrough of it; for the TorchText API, refer to the reference documentation and related blog posts.

This example shows how to train a supervised learning algorithm for classification using one of these TextClassification datasets.

The ngrams feature is used to capture some partial information about local word order. In practice, applying bi-grams or tri-grams as word groups provides more benefit than single words alone. An example:

"load data with ngrams"Bi-grams results: "load data", "data with", "with ngrams"Tri-grams results: "load data with", "data with ngrams"

The TextClassification datasets support the ngrams method. By setting ngrams to 2, the example text in the dataset becomes a list of single words plus bi-gram strings.
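
As a rough illustration of the same idea, torchtext ships an ngrams_iterator utility (assuming a torchtext version that provides torchtext.data.utils.ngrams_iterator); it yields the original tokens plus their n-grams:

from torchtext.data.utils import get_tokenizer, ngrams_iterator

tokenizer = get_tokenizer('basic_english')
tokens = tokenizer("load data with ngrams")      # ['load', 'data', 'with', 'ngrams']
# yields each token first, then every bi-gram
print(list(ngrams_iterator(tokens, 2)))
# ['load', 'data', 'with', 'ngrams', 'load data', 'data with', 'with ngrams']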

Install it with:

pip install torchtext

The from torchtext.datasets import text_classification code in the original tutorial no longer works, and the parameters of text_classification.DATASETS['AG_NEWS'] have changed as well; see the English manual for details.

1. Access the raw dataset iterators

The torchtext library provides a few raw dataset iterators, which yield the raw text strings. For example, the AG_NEWS dataset iterator yields the raw data as tuples of label and text.

Calling train_data, test_dataset = AG_NEWS(root=path, split=('train', 'test')) may raise the following error:

TimeoutError: [WinError 10060] A connection attempt failed because the connected party did not properly respond after a period of time, or established connection failed because connected host has failed to respond

In that case, download the files directly from these URLs:

URL = {
    'train': "https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/train.csv",
    'test': "https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/test.csv",
}
from torchtext.datasets import AG_NEWS
path = '... your path\\AG_NEWS.data'
train_data, test_dataset = AG_NEWS(root=path, split=('train', 'test'))
print(next(train_data))
print(next(train_data))
(3, "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.")(3, 'Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.')

2. Prepare the data processing pipelines

We have revisited the most basic components of the torchtext library, including the vocab, word vectors, and tokenizer. These are the basic data processing building blocks for raw text strings.

Here is a typical example of NLP data processing with the tokenizer and vocabulary. The first step is to build a vocabulary from the raw training dataset. Users can have a customized vocabulary by setting arguments in the constructor of the Vocab class, for example min_freq, the minimum frequency required to include a token.
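
A minimal sketch of this step, mirroring the Counter/Vocab calls used in the full code at the end (it assumes train_iter is the AG_NEWS training iterator from section 1):

from collections import Counter
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import Vocab

tokenizer = get_tokenizer('basic_english')
counter = Counter()
for (label, line) in train_iter:        # train_iter yields (label, text) tuples
    counter.update(tokenizer(line))
# tokens that appear fewer than min_freq times are excluded from the vocabulary
vocab = Vocab(counter, min_freq=1)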

As for lambda, this expression is an anonymous function, corresponding to a named function defined with def in Python.
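
For instance, the label pipeline used later could be written either way; a tiny illustration:

# as a lambda (anonymous function):
label_pipeline = lambda x: int(x) - 1

# the equivalent named function:
def label_pipeline(x):
    return int(x) - 1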

The vocabulary block converts a list of tokens into integers:

[vocab[token] for token in ['here', 'is', 'an', 'example']]
>>> [476, 22, 31, 5298]

Prepare the text processing pipelines with the tokenizer and vocabulary. The text and label pipelines will be used to process the raw data strings from the dataset iterators.

The text pipeline converts a text string into a list of integers based on the lookup table defined in the vocabulary. The label pipeline converts the label into an integer. For example:

text_pipeline('here is the an example')
>>> [475, 21, 2, 30, 5286]
label_pipeline('10')
>>> 9

3. Generate data batches and iterators

torch.utils.data.DataLoader is recommended for PyTorch users (see the official tutorial). It works with map-style datasets that implement the getitem() and len() protocols and represent a map from indices/keys to data samples. It also works with iterable datasets when the shuffle argument is False.

Before being sent to the model, a batch of samples generated by the DataLoader is processed by the collate_fn function. The input to collate_fn is a batch of data with the DataLoader's batch size, and collate_fn processes it according to the data processing pipelines declared earlier. Note that collate_fn must be declared as a top-level def so that the function is available in every worker.

In this example, the text entries in the original data batch are packed into a list and concatenated into a single tensor as the input to nn.EmbeddingBag. The offsets are a tensor of delimiters that mark the starting index of each individual sequence in the text tensor. Label is a tensor holding the labels of the individual text entries.

On the usage of the torch.cumsum() function:

x = torch.arange(0, 6).view(2, 3)
print(x)
print(x.cumsum(dim=0))
print(x.cumsum(dim=1))
tensor([[0, 1, 2],
        [3, 4, 5]])
tensor([[0, 1, 2],
        [3, 5, 7]])
tensor([[ 0,  1,  3],
        [ 3,  7, 12]])

My understanding is that collate_fn receives a batch of data from the list of samples and, through the mapping functions, turns it into tensors.
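
A small worked example of how the collate function derives offsets from the per-text lengths (the lengths 41, 58 and 30 are made up for illustration; the last length is dropped and a leading 0 is kept, exactly as in collate_batch in the full code):

import torch

lengths_with_leading_zero = [0, 41, 58]     # lengths of texts 1 and 2; the length of the last text (30) is not needed
offsets = torch.tensor(lengths_with_leading_zero).cumsum(dim=0)
print(offsets)                              # tensor([ 0, 41, 99]) -> start index of each text in the concatenated tensor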

4. Define the model

The model is composed of an nn.EmbeddingBag layer plus a linear layer for the classification purpose. nn.EmbeddingBag with the default mode of "mean" computes the mean value of a "bag" of embeddings. Although the text entries here have different lengths, the nn.EmbeddingBag module requires no padding, since the text lengths are saved in offsets.

In addition, since nn.EmbeddingBag accumulates the average of the embeddings on the fly, it can improve performance and memory efficiency when processing a sequence of tensors.
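
A sketch of this EmbeddingBag-plus-linear model (it mirrors the TextClassificationModel class in the full code at the end; the weight initialization is omitted here for brevity):

from torch import nn

class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super().__init__()
        # mode defaults to 'mean': each text is represented by the mean of its word embeddings
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)    # (batch_size, embed_dim)
        return self.fc(embedded)                    # (batch_size, num_class)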


Regarding the EmbeddingBag() function: compared with Embedding it has only one extra parameter, mode, which takes three values corresponding to three operations. "sum" is equivalent to an ordinary embedding followed by torch.sum(dim=0), "mean" is equivalent to appending torch.mean(dim=0), and "max" is equivalent to appending torch.max(dim=0).

An example of this layer's input and output:

>>> # an Embedding module containing 10 tensors of size 3
>>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')
>>> # a batch of 2 samples of 4 indices each
>>> input = torch.LongTensor([1,2,4,5,4,3,2,9])
>>> offsets = torch.LongTensor([0,4])
>>> embedding_sum(input, offsets)
tensor([[-0.8861, -5.4350, -0.0523],
        [ 1.1306, -2.5798, -1.0044]])

5. Initiate an instance

The AG_NEWS dataset has four labels, so the number of classes is four:

1 : World
2 : Sports
3 : Business
4 : Sci/Tec

We build a model with an embedding dimension of 64. The vocab size equals the length of the vocabulary instance, and the number of classes equals the number of labels, which is 4.
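
Instantiating the model with these settings (a sketch that mirrors the full code; emsize is the embedding dimension):

num_class = 4       # World, Sports, Business, Sci/Tec
vocab_size = len(vocab)
emsize = 64
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TextClassificationModel(vocab_size, emsize, num_class).to(device)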

6. Define functions to train the model and evaluate results

On adjusting the learning rate: torch.optim.lr_scheduler provides several methods to adjust the learning rate based on the number of epochs.

torch.optim.lr_scheduler.StepLR decays the learning rate of each parameter group by gamma every step_size epochs. Note that this decay can happen simultaneously with other changes to the learning rate from outside this scheduler. When last_epoch=-1, it sets the initial lr to lr.
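
For example, with the optimizer settings used later (initial learning rate 5.0, gamma 0.1), a minimal StepLR sketch looks like this:

optimizer = torch.optim.SGD(model.parameters(), lr=5.0)
# after every step_size calls to scheduler.step(), the learning rate is multiplied by gamma
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.1)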

As for torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1): it clips the gradient norm of an iterable of parameters. The norm is computed over all gradients together, as if they were concatenated into a single vector, and the gradients are modified in place. In other words, this is gradient clipping, with max_norm specifying the maximum norm that may not be exceeded.
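
In the training loop this call sits between loss.backward() and optimizer.step(); a minimal sketch:

loss.backward()
# rescale all gradients in place so that their combined norm is at most 0.1
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
optimizer.step()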

For each batch, the predicted predited_label is a 64×4 tensor, while label is a one-dimensional tensor of length 64:

tensor([[ 0.4427,  0.0830,  0.0109,  0.1273],
        [ 0.1601,  0.0869, -0.0540,  0.0422],
        ...
tensor([0, 0, 0, 3, 1, 1, 1, 3, 3, 3, 3, 3, 1, 1, 3, 1, 1, 3, 3, 3, 1, 1, 3, 3,
        3, 1, 1, 2, 1, 2, 1, 1, 3, 3, 1, 1, 1, 3, 1, 3, 0, 1, 0, 0, 1, 3, 3, 3,
        2, 3, 1, 3, 3, 3, 1, 3, 3, 1, 1, 2, 0, 2, 1, 3])

Previously we used the .topk() function; here let's look at .argmax(1) instead:

print(predited_label.argmax(1) == label)
tensor([False,  True,  True,  True, False,  True,  True,  True,  True,  True,
         True, False,  True,  True,  True, False,  True,  True,  True,  True,
         True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
         True, False,  True, False,  True,  True,  True,  True, False,  True,
         True,  True, False,  True, False,  True,  True, False,  True,  True,
         True, False, False,  True,  True, False,  True, False, False,  True,
        False,  True,  True,  True])

Running the following code yields a single number, the count of correct predictions in the batch:

(predited_label.argmax(1) == label).sum().item()
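
Putting it together, the per-batch accuracy is just that count divided by the batch size; a small sketch using the tensors above:

correct = (predited_label.argmax(1) == label).sum().item()     # number of correct predictions in the batch
batch_acc = correct / label.size(0)                            # fraction of the 64 samples predicted correctly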

7. Split the dataset and run the model

Since the original AG_NEWS has no validation dataset, we split the training dataset into train/valid sets with a split ratio of 0.95 (train) and 0.05 (valid). Here we use the torch.utils.data.dataset.random_split function from the PyTorch core library.
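
The split itself, as done in the (commented-out) training section of the full code; it assumes path points at the dataset directory as before:

from torch.utils.data.dataset import random_split
from torchtext.datasets import AG_NEWS

train_dataset = list(AG_NEWS(root=path, split='train'))
num_train = int(len(train_dataset) * 0.95)      # 95% train, 5% validation
split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])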

The CrossEntropyLoss criterion combines nn.LogSoftmax() and nn.NLLLoss() in a single class. It is useful when training a classification problem with C classes. SGD implements stochastic gradient descent as the optimizer. The initial learning rate is set to 5.0, and StepLR is used to adjust the learning rate across epochs.

The training log:

| epoch   1 |   500/ 1782 batches, accuracy    0.685
| epoch   1 |  1000/ 1782 batches, accuracy    0.852
| epoch   1 |  1500/ 1782 batches, accuracy    0.876
-----------------------------------------------------------
| end of epoch   1 | time: 15.24s | valid accuracy    0.886 
-----------------------------------------------------------
| epoch   2 |   500/ 1782 batches, accuracy    0.896
| epoch   2 |  1000/ 1782 batches, accuracy    0.902
| epoch   2 |  1500/ 1782 batches, accuracy    0.902
-----------------------------------------------------------
| end of epoch   2 | time: 15.20s | valid accuracy    0.899 
-----------------------------------------------------------
| epoch   3 |   500/ 1782 batches, accuracy    0.915
| epoch   3 |  1000/ 1782 batches, accuracy    0.914
| epoch   3 |  1500/ 1782 batches, accuracy    0.915
-----------------------------------------------------------
| end of epoch   3 | time: 15.22s | valid accuracy    0.904 
-----------------------------------------------------------
| epoch   4 |   500/ 1782 batches, accuracy    0.924
| epoch   4 |  1000/ 1782 batches, accuracy    0.924
| epoch   4 |  1500/ 1782 batches, accuracy    0.923
-----------------------------------------------------------
| end of epoch   4 | time: 15.16s | valid accuracy    0.908 
-----------------------------------------------------------
| epoch   5 |   500/ 1782 batches, accuracy    0.930
| epoch   5 |  1000/ 1782 batches, accuracy    0.929
| epoch   5 |  1500/ 1782 batches, accuracy    0.931
-----------------------------------------------------------
| end of epoch   5 | time: 15.21s | valid accuracy    0.900 
-----------------------------------------------------------
| epoch   6 |   500/ 1782 batches, accuracy    0.943
| epoch   6 |  1000/ 1782 batches, accuracy    0.941
| epoch   6 |  1500/ 1782 batches, accuracy    0.944
-----------------------------------------------------------
| end of epoch   6 | time: 15.17s | valid accuracy    0.911 
-----------------------------------------------------------
| epoch   7 |   500/ 1782 batches, accuracy    0.943
| epoch   7 |  1000/ 1782 batches, accuracy    0.945
| epoch   7 |  1500/ 1782 batches, accuracy    0.946
-----------------------------------------------------------
| end of epoch   7 | time: 15.24s | valid accuracy    0.912 
-----------------------------------------------------------
| epoch   8 |   500/ 1782 batches, accuracy    0.945
| epoch   8 |  1000/ 1782 batches, accuracy    0.944
| epoch   8 |  1500/ 1782 batches, accuracy    0.944
-----------------------------------------------------------
| end of epoch   8 | time: 15.20s | valid accuracy    0.913 
-----------------------------------------------------------
| epoch   9 |   500/ 1782 batches, accuracy    0.944
| epoch   9 |  1000/ 1782 batches, accuracy    0.948
| epoch   9 |  1500/ 1782 batches, accuracy    0.946
-----------------------------------------------------------
| end of epoch   9 | time: 15.29s | valid accuracy    0.915 
-----------------------------------------------------------
| epoch  10 |   500/ 1782 batches, accuracy    0.949
| epoch  10 |  1000/ 1782 batches, accuracy    0.945
| epoch  10 |  1500/ 1782 batches, accuracy    0.946
-----------------------------------------------------------
| end of epoch  10 | time: 15.19s | valid accuracy    0.913 
-----------------------------------------------------------
Checking the results of test dataset.
test accuracy    0.908

For a sentence like this:

"MEMPHIS, Tenn. – Four days ago, Jon Rahm was enduring the season’s worst weather conditions on Sunday at The Open on his way to a closing 75 at Royal Portrush, which considering the wind and the rain was a respectable showing. Thursday’s first round at the WGC-FedEx St. Jude Invitational was another story. With temperatures in the mid-80s and hardly any wind, the Spaniard was 13 strokes better in a flawless round. Thanks to his best putting performance on the PGA Tour, Rahm finished with an 8-under 62 for a three-stroke lead, which was even more impressive considering he’d never played the front nine at TPC Southwind."

Output:

This is a Sports news

And for this sentence:

'Beijing of Automation, Beijing Institute of Technology'

Output:

This is a Sci/Tec news

The classification results look fairly reasonable.

8. Full code

path = '... your path\\AG_NEWS.data'
import torch
from torchtext.datasets import AG_NEWS
train_iter = AG_NEWS(root=path, split='train')      # access the raw dataset iterator

from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab

tokenizer = get_tokenizer('basic_english')      # tokenizes an input string
counter = Counter()
for (label, line) in train_iter:
    counter.update(tokenizer(line))
vocab = Vocab(counter, min_freq=1)

# prepare the data processing pipelines
text_pipeline = lambda x: [vocab[token] for token in tokenizer(x)]      # a token is a word; vocab[token] is its integer index
label_pipeline = lambda x: int(x) - 1       # map the labels 1, 2, 3, 4 to the classes 0, 1, 2, 3

# generate data batches and iterators
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)      # torch.Size([41]), torch.Size([58])...
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)        # torch.Size([64])
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)      # torch.Size([64])
    text_list = torch.cat(text_list)        # turn a list of tensors into one tensor
    return label_list.to(device), text_list.to(device), offsets.to(device)

# dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)
# import ipdb

from torch import nn

class TextClassificationModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)      # fill the tensor with values sampled from a uniform distribution
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)        # torch.Size([64, 64])
        output = self.fc(embedded)      # torch.Size([64, 4])
        return output

# num_class = len(set([label for (label, text) in train_iter]))       # the iterator would need to be restarted to compute this, i.e. train_iter = AG_NEWS(root=path, split='train')
num_class = 4
vocab_size = len(vocab)
emsize = 64
model = TextClassificationModel(vocab_size, emsize, num_class).to(device)

import time

def train(dataloader):
    model.train()       # training mode
    total_acc, total_count = 0, 0
    log_interval = 500
    start_time = time.time()
    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predited_label = model(text, offsets)
        loss = criterion(predited_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)     # max_norm: the maximum gradient norm allowed
        optimizer.step()
        total_acc += (predited_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches, accuracy {:8.3f}'.format(epoch, idx, len(dataloader), total_acc / total_count))
            total_acc, total_count = 0, 0
            start_time = time.time()

def evaluate(dataloader):
    model.eval()
    total_acc, total_count = 0, 0
    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predited_label = model(text, offsets)
            # loss = criterion(predited_label, label)
            total_acc += (predited_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc / total_count

def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text))
        output = model(text, torch.tensor([0]))
        return output.argmax(1).item() + 1

from torch.utils.data.dataset import random_split

if __name__ == '__main__':
    # Hyperparameters
    # EPOCHS = 10  # epoch
    # LR = 5  # learning rate
    # BATCH_SIZE = 64  # batch size for training
    #
    # criterion = torch.nn.CrossEntropyLoss()
    # optimizer = torch.optim.SGD(model.parameters(), lr=LR)
    # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
    # total_accu = None
    # train_iter, test_iter = AG_NEWS(root=path)
    # train_dataset = list(train_iter)
    # test_dataset = list(test_iter)
    # num_train = int(len(train_dataset) * 0.95)
    # split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train])
    #
    # train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)      # shuffle randomly reorders the samples
    # valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    # test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch)
    #
    # for epoch in range(1, EPOCHS + 1):
    #     epoch_start_time = time.time()
    #     train(train_dataloader)
    #     accu_val = evaluate(valid_dataloader)
    #     if total_accu is not None and total_accu > accu_val:
    #         scheduler.step()
    #     else:
    #         total_accu = accu_val
    #     print('-' * 59)
    #     print('| end of epoch {:3d} | time: {:5.2f}s | '
    #           'valid accuracy {:8.3f} '.format(epoch, time.time() - epoch_start_time, accu_val))
    #     print('-' * 59)
    #
    # print('Checking the results of test dataset.')
    # accu_test = evaluate(test_dataloader)
    # print('test accuracy {:8.3f}'.format(accu_test))
    #
    # torch.save(model.state_dict(), '... your path\\model_TextClassification.pth')

    # evaluation below
    model.load_state_dict(torch.load('... your path\\model_TextClassification.pth'))
    ag_news_label = {1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tec"}
    # ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was enduring the season’s worst weather conditions on Sunday at The Open on his way to a closing 75 at Royal Portrush, which considering the wind and the rain was a respectable showing. Thursday’s first round at the WGC-FedEx St. Jude Invitational was another story. With temperatures in the mid-80s and hardly any wind, the Spaniard was 13 strokes better in a flawless round. Thanks to his best putting performance on the PGA Tour, Rahm finished with an 8-under 62 for a three-stroke lead, which was even more impressive considering he’d never played the front nine at TPC Southwind."
    ex_text_str = 'Beijing of Automation, Beijing Institute of Technology'
    # model = model.to("cpu")
    print("This is a %s news" % ag_news_label[predict(ex_text_str, text_pipeline)])

Summary

  1. Getting the dataset involved a few pitfalls: the Chinese tutorial is wrong and not kept up to date, so it is better to read the English one; and GitHub kept stalling during the download, so I had to finish it with IDM.
  2. Preparing the data pipelines is essentially a one-hot-style mapping of English words to integer indices.
  3. The DataLoader for data batches and iterators should be quite important: it processes the data as a stream instead of reading everything in at once and blowing up the memory; collate_fn, which turns a batch into tensors, was a bit hard to grasp at first.
  4. The model is fairly simple: each word is embedded and the embeddings are averaged to represent a sentence.
  5. The training loop has a learning-rate update step worth borrowing; the validation set it builds doesn't seem to be used for much, though...

Future work:

  1. Reproduce the code of the other TorchText experiment
  2. Study the BERT and Transformer models and implement them

