本文共 17034 字,大约阅读时间需要 56 分钟。
这是官方文本篇的一个教程,原,,原,介绍了如何使用torchtext中的文本分类数据集,本文是其详细的注解,关于TorchText API的,参考和博客
本示例说明了如何使用这些TextClassification
数据集之一训练用于分类的监督学习算法
ngrams功能用于捕获有关本地单词顺序的一些部分信息。 在实践中,应用二元语法或三元语法作为单词组比仅仅一个单词提供更多的好处。 一个例子:
"load data with ngrams"Bi-grams results: "load data", "data with", "with ngrams"Tri-grams results: "load data with", "data with ngrams"
TextClassification
数据集支持 ngrams 方法。 通过将 ngrams 设置为 2,数据集中的示例文本将是一个单字加 bi-grams 字符串的列表
输入以下代码进行安装:
pip install torchtext
原文的这个from torchtext.datasets import text_classification
代码是错的,而且text_classification.DATASETS['AG_NEWS']
的参数都变了,详见英文手册
torchtext 库提供了一些原始数据集迭代器,这些迭代器产生原始文本字符串。例如,AG_NEWS数据集迭代器产生的原始数据是标签和文本的元组
使用此函数时train_data, test_dataset = AG_NEWS(root=path, split=('train', 'test'))
会报错:
TimeoutError: [WinError 10060] A connection attempt failed because the connected party did not properly respond after a period of time, or established connection failed because connected host has failed to respond
这里直接打开url进行下载:
URL = { 'train': "https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/train.csv", 'test': "https://raw.githubusercontent.com/mhjabreel/CharCnn_Keras/master/data/ag_news_csv/test.csv",}
from torchtext.datasets import AG_NEWSpath = '... your path\\AG_NEWS.data'train_data, test_dataset = AG_NEWS(root=path, split=('train', 'test'))print(next(train_data))print(next(train_data))
(3, "Wall St. Bears Claw Back Into the Black (Reuters) Reuters - Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green again.")(3, 'Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private investment firm Carlyle Group,\\which has a reputation for making well-timed and occasionally\\controversial plays in the defense industry, has quietly placed\\its bets on another part of the market.')
我们已经重新审视了torchtext库中最基本的组件,包括vocab、单词向量、tokenizer。这些都是原始文本字符串的基本数据处理构件
这里是一个典型的NLP数据处理的例子,使用tokenizer和词汇。第一步是用原始训练数据集建立一个词汇表,用户可以通过在Vocab类的构造函数中设置参数来拥有一个自定义的词汇表。用户可以通过在Vocab类的构造函数中设置参数来拥有一个自定义的词汇表。例如,要包含的令牌的最小频率min_freq
对于函数lambda
,此表达式是一种匿名函数,对应python中的自定义函数def
词汇块将一个tokens列表转换成整数
[vocab[token] for token in ['here', 'is', 'an', 'example']]>>> [476, 22, 31, 5298]
用标记器和词汇准备文本处理管道。文本和标签流水线将用于处理来自数据集迭代器的原始数据字符串
文本流水线根据词汇表中定义的查找表将文本字符串转换为整数列表。标签流水线将标签转换为整数。例如:
text_pipeline('here is the an example')>>> [475, 21, 2, 30, 5286]label_pipeline('10')>>> 9
torch.utils.data.DataLoader
推荐给 PyTorch 用户使用(教程在)。它适用于实现 getitem()
和 len()
协议的地图式数据集,并表示从索引/键到数据样本的映射。它也适用于shuffle argumnent为False的可迭代数据集
在发送至模型之前, collate_fn 函数对 DataLoader 中生成的一批样本进行处理。collate_fn的输入是DataLoader中批量大小的数据, collate_fn根据之前声明的数据处理管道对它们进行处理。这里要注意,一定要将 collate_fn 声明为顶层 def,这样才能保证该函数在每个 worker 中都能使用
在这个例子中,原始数据批输入中的文本条目被打包成一个列表,并作为一个单一的张量来连接nn.EmbeddingBag的输入。偏移量是一个定界符的张量,用于表示文本张量中各个序列的起始索引。Label是一个张量,保存了indidividual文本条目的标签
关于torch.cumsum()
函数的用法:
x = torch.arange(0, 6).view(2, 3)print(x)print(x.cumsum(dim=0))print(x.cumsum(dim=1))
tensor([[0, 1, 2], [3, 4, 5]])tensor([[0, 1, 2], [3, 5, 7]])tensor([[ 0, 1, 3], [ 3, 7, 12]])
个人理解collate_fn
是从样本列表中过来了一个batch的数据,经过映射函数,形成一个tensor
该模型由nn.EmbeddingBag层加上一个线性层组成,以达到分类的目的。nn.EmbeddingBag默认模式为 “mean”,计算一个 "袋 "的嵌入物的平均值。虽然这里的文本条目有不同的长度,但由于文本长度是以偏移量保存的,所以nn.EmbeddingBag模块在这里不需要填充
另外,由于nn.EmbeddingBag会动态累积嵌入中的平均值,因此nn.EmbeddingBag可以提高性能和存储效率,以处理张量序列
关于EmbeddingBag()
函数,,参考此,参数只多了一个:mode,来看这个参数的取值有三种,对应三种操作:"sum"表示普通embedding后接torch.sum(dim=0),"mean"相当于后接torch.mean(dim=0),"max"相当于后接torch.max(dim=0) 此网络输入输出的例子:
>>> # an Embedding module containing 10 tensors of size 3>>> embedding_sum = nn.EmbeddingBag(10, 3, mode='sum')>>> # a batch of 2 samples of 4 indices each>>> input = torch.LongTensor([1,2,4,5,4,3,2,9])>>> offsets = torch.LongTensor([0,4])>>> embedding_sum(input, offsets)tensor([[-0.8861, -5.4350, -0.0523], [ 1.1306, -2.5798, -1.0044]])
AG_NEWS数据集有四个标签,因此类的数量是四个:
1 : World2 : Sports3 : Business4 : Sci/Tec
我们建立一个嵌入维度为64的模型,vocab大小等于词汇实例的长度,类的数量等于标签的数量4
关于调整学习率,,函数:torch.optim.lr_scheduler
提供了几种方法来调整基于epochs的学习率
torch.optim.lr_scheduler.StepLR
每隔一个step_size epochs,将每个参数组的学习率按gamma衰减。请注意,这种衰减可以与其他来自这个调度器外部的学习率变化同时发生。当last_epoch=-1时,设置初始lr为lr
关于torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
函数,作用是剪切参数迭代的梯度法线,,法线是在所有梯度上一起计算的,就像它们被连成一个向量一样。梯度是就地修改的,即:梯度剪切,规定了最大不能超过的max_norm
对于每一个batch预测的predited_label
,是一个64*4的tensor,对于每一个label
,是一个64的一维的tensor
tensor([[ 0.4427, 0.0830, 0.0109, 0.1273], [ 0.1601, 0.0869, -0.0540, 0.0422], ...
tensor([0, 0, 0, 3, 1, 1, 1, 3, 3, 3, 3, 3, 1, 1, 3, 1, 1, 3, 3, 3, 1, 1, 3, 3, 3, 1, 1, 2, 1, 2, 1, 1, 3, 3, 1, 1, 1, 3, 1, 3, 0, 1, 0, 0, 1, 3, 3, 3, 2, 3, 1, 3, 3, 3, 1, 3, 3, 1, 1, 2, 0, 2, 1, 3])
之前我们用的是.topk()函数,这里了解一下.argmax(1)
函数:
print(predited_label.argmax(1) == label)
tensor([False, True, True, True, False, True, True, True, True, True, True, False, True, True, True, False, True, True, True, True, True, True, True, True, True, True, True, True, True, True, True, False, True, False, True, True, True, True, False, True, True, True, False, True, False, True, True, False, True, True, True, False, False, True, True, False, True, False, False, True, False, True, True, True])
执行以下代码输出就是一个常数:
(predited_label.argmax(1) == label).sum().item()
由于原AG_NEWS没有有效数据集,我们将训练数据集拆分为训练/有效集,拆分比例为0.95(训练)和0.05(有效)。这里我们使用PyTorch核心库中的torch.utils.data.dataset.random_split函数
CrossEntropyLoss准则将nn.LogSoftmax()和nn.NLLLoss()结合在一个类中。它在训练C类的分类问题时非常有用。SGD实现了随机梯度下降法作为优化器。初始学习率设置为5.0。这里使用StepLR通过epochs来调整学习率
打印训练过程:
| epoch 1 | 500/ 1782 batches, accuracy 0.685| epoch 1 | 1000/ 1782 batches, accuracy 0.852| epoch 1 | 1500/ 1782 batches, accuracy 0.876-----------------------------------------------------------| end of epoch 1 | time: 15.24s | valid accuracy 0.886 -----------------------------------------------------------| epoch 2 | 500/ 1782 batches, accuracy 0.896| epoch 2 | 1000/ 1782 batches, accuracy 0.902| epoch 2 | 1500/ 1782 batches, accuracy 0.902-----------------------------------------------------------| end of epoch 2 | time: 15.20s | valid accuracy 0.899 -----------------------------------------------------------| epoch 3 | 500/ 1782 batches, accuracy 0.915| epoch 3 | 1000/ 1782 batches, accuracy 0.914| epoch 3 | 1500/ 1782 batches, accuracy 0.915-----------------------------------------------------------| end of epoch 3 | time: 15.22s | valid accuracy 0.904 -----------------------------------------------------------| epoch 4 | 500/ 1782 batches, accuracy 0.924| epoch 4 | 1000/ 1782 batches, accuracy 0.924| epoch 4 | 1500/ 1782 batches, accuracy 0.923-----------------------------------------------------------| end of epoch 4 | time: 15.16s | valid accuracy 0.908 -----------------------------------------------------------| epoch 5 | 500/ 1782 batches, accuracy 0.930| epoch 5 | 1000/ 1782 batches, accuracy 0.929| epoch 5 | 1500/ 1782 batches, accuracy 0.931-----------------------------------------------------------| end of epoch 5 | time: 15.21s | valid accuracy 0.900 -----------------------------------------------------------| epoch 6 | 500/ 1782 batches, accuracy 0.943| epoch 6 | 1000/ 1782 batches, accuracy 0.941| epoch 6 | 1500/ 1782 batches, accuracy 0.944-----------------------------------------------------------| end of epoch 6 | time: 15.17s | valid accuracy 0.911 -----------------------------------------------------------| epoch 7 | 500/ 1782 batches, accuracy 0.943| epoch 7 | 1000/ 1782 batches, accuracy 0.945| epoch 7 | 1500/ 1782 batches, accuracy 0.946-----------------------------------------------------------| end of epoch 7 | time: 15.24s | valid accuracy 0.912 -----------------------------------------------------------| epoch 8 | 500/ 1782 batches, accuracy 0.945| epoch 8 | 1000/ 1782 batches, accuracy 0.944| epoch 8 | 1500/ 1782 batches, accuracy 0.944-----------------------------------------------------------| end of epoch 8 | time: 15.20s | valid accuracy 0.913 -----------------------------------------------------------| epoch 9 | 500/ 1782 batches, accuracy 0.944| epoch 9 | 1000/ 1782 batches, accuracy 0.948| epoch 9 | 1500/ 1782 batches, accuracy 0.946-----------------------------------------------------------| end of epoch 9 | time: 15.29s | valid accuracy 0.915 -----------------------------------------------------------| epoch 10 | 500/ 1782 batches, accuracy 0.949| epoch 10 | 1000/ 1782 batches, accuracy 0.945| epoch 10 | 1500/ 1782 batches, accuracy 0.946-----------------------------------------------------------| end of epoch 10 | time: 15.19s | valid accuracy 0.913 -----------------------------------------------------------Checking the results of test dataset.test accuracy 0.908
对于这样一个句子:
"MEMPHIS, Tenn. – Four days ago, Jon Rahm was enduring the season’s worst weather conditions on Sunday at The Open on his way to a closing 75 at Royal Portrush, which considering the wind and the rain was a respectable showing. Thursday’s first round at the WGC-FedEx St. Jude Invitational was another story. With temperatures in the mid-80s and hardly any wind, the Spaniard was 13 strokes better in a flawless round. Thanks to his best putting performance on the PGA Tour, Rahm finished with an 8-under 62 for a three-stroke lead, which was even more impressive considering he’d never played the front nine at TPC Southwind."
输出结果:
This is a Sports news
对于这样一个句子:
'Beijing of Automation, Beijing Institute of Technology'
输出结果:
This is a Sci/Tec news
可以发现分类结果还是比较理想的
path = '... your path\\AG_NEWS.data'import torchfrom torchtext.datasets import AG_NEWStrain_iter = AG_NEWS(root=path, split='train') # 访问原始数据集迭代器from torchtext.data.utils import get_tokenizerfrom collections import Counterfrom torchtext.vocab import Vocabtokenizer = get_tokenizer('basic_english') # 输入的字符串counter = Counter()for (label, line) in train_iter: counter.update(tokenizer(line))vocab = Vocab(counter, min_freq=1)# 准备数据处理管道text_pipeline = lambda x: [vocab[token] for token in tokenizer(x)] # token就是word,vocab[token]就是其对应的数字label_pipeline = lambda x: int(x) - 1 # 把1、2、3、4 转化为 0、1、2、3 四类# 生成数据批次和迭代器from torch.utils.data import DataLoaderdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")def collate_batch(batch): label_list, text_list, offsets = [], [], [0] for (_label, _text) in batch: label_list.append(label_pipeline(_label)) processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64) # torch.Size([41]), torch.Size([58])... text_list.append(processed_text) offsets.append(processed_text.size(0)) label_list = torch.tensor(label_list, dtype=torch.int64) # torch.Size([64]) offsets = torch.tensor(offsets[:-1]).cumsum(dim=0) # torch.Size([64]) text_list = torch.cat(text_list) # 若干tensor组成的列表变成一个tensor return label_list.to(device), text_list.to(device), offsets.to(device)# dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)# import ipdbfrom torch import nnclass TextClassificationModel(nn.Module): def __init__(self, vocab_size, embed_dim, num_class): super(TextClassificationModel, self).__init__() self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True) self.fc = nn.Linear(embed_dim, num_class) self.init_weights() def init_weights(self): initrange = 0.5 self.embedding.weight.data.uniform_(-initrange, initrange) # 将tensor用从均匀分布中抽样得到的值填充 self.fc.weight.data.uniform_(-initrange, initrange) self.fc.bias.data.zero_() def forward(self, text, offsets): embedded = self.embedding(text, offsets) # torch.Size([64, 64]) output = self.fc(embedded) # torch.Size([64, 4]) return output# num_class = len(set([label for (label, text) in train_iter])) # 迭代器需要重新开始才能计算...即train_iter = AG_NEWS(root=path, split='train') # 访问原始数据集迭代器num_class = 4vocab_size = len(vocab)emsize = 64model = TextClassificationModel(vocab_size, emsize, num_class).to(device)import timedef train(dataloader): model.train() # 训练模式 total_acc, total_count = 0, 0 log_interval = 500 start_time = time.time() for idx, (label, text, offsets) in enumerate(dataloader): optimizer.zero_grad() predited_label = model(text, offsets) loss = criterion(predited_label, label) loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1) # 规定了最大不能超过的max_norm optimizer.step() total_acc += (predited_label.argmax(1) == label).sum().item() total_count += label.size(0) if idx % log_interval == 0 and idx > 0: elapsed = time.time() - start_time print('| epoch {:3d} | {:5d}/{:5d} batches, accuracy {:8.3f}'.format(epoch, idx, len(dataloader), total_acc / total_count)) total_acc, total_count = 0, 0 start_time = time.time()def evaluate(dataloader): model.eval() total_acc, total_count = 0, 0 with torch.no_grad(): for idx, (label, text, offsets) in enumerate(dataloader): predited_label = model(text, offsets) # loss = criterion(predited_label, label) total_acc += (predited_label.argmax(1) == label).sum().item() total_count += label.size(0) return total_acc / total_countdef predict(text, text_pipeline): with torch.no_grad(): text = torch.tensor(text_pipeline(text)) output = model(text, torch.tensor([0])) return output.argmax(1).item() + 1from torch.utils.data.dataset import random_splitif __name__ == '__main__': # 超参数(Hyperparameters) # EPOCHS = 10 # epoch # LR = 5 # learning rate # BATCH_SIZE = 64 # batch size for training # # criterion = torch.nn.CrossEntropyLoss() # optimizer = torch.optim.SGD(model.parameters(), lr=LR) # scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1) # total_accu = None # train_iter, test_iter = AG_NEWS(root=path) # train_dataset = list(train_iter) # test_dataset = list(test_iter) # num_train = int(len(train_dataset) * 0.95) # split_train_, split_valid_ = random_split(train_dataset, [num_train, len(train_dataset) - num_train]) # # train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch) # shuffle表示随机打乱 # valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch) # test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_batch) # # for epoch in range(1, EPOCHS + 1): # epoch_start_time = time.time() # train(train_dataloader) # accu_val = evaluate(valid_dataloader) # if total_accu is not None and total_accu > accu_val: # scheduler.step() # else: # total_accu = accu_val # print('-' * 59) # print('| end of epoch {:3d} | time: {:5.2f}s | ' # 'valid accuracy {:8.3f} '.format(epoch, time.time() - epoch_start_time, accu_val)) # print('-' * 59) # # # print('Checking the results of test dataset.') # accu_test = evaluate(test_dataloader) # print('test accuracy {:8.3f}'.format(accu_test)) # # torch.save(model.state_dict(), '... your path\\model_TextClassification.pth') # 以下是评估 model.load_state_dict(torch.load('... your path\\model_TextClassification.pth')) ag_news_label = { 1: "World", 2: "Sports", 3: "Business", 4: "Sci/Tec"} # ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was enduring the season’s worst weather conditions on Sunday at The Open on his way to a closing 75 at Royal Portrush, which considering the wind and the rain was a respectable showing. Thursday’s first round at the WGC-FedEx St. Jude Invitational was another story. With temperatures in the mid-80s and hardly any wind, the Spaniard was 13 strokes better in a flawless round. Thanks to his best putting performance on the PGA Tour, Rahm finished with an 8-under 62 for a three-stroke lead, which was even more impressive considering he’d never played the front nine at TPC Southwind." ex_text_str = 'Beijing of Automation, Beijing Institute of Technology' # model = model.to("cpu") print("This is a %s news" % ag_news_label[predict(ex_text_str, text_pipeline)])
未来工作:
转载地址:http://dwtrn.baihongyu.com/