构造语言模型数据集

十年热恋 提交于 2020-08-04 11:19:33

语言模型是根据当前词预测下一个词,一次构造的数据集中inputs假如为 abcdefg,则target为bcdefgh。

数据使用text8,无标点无换行的英文数据集。以下使用torchtext进行数据集的预处理。数据预处理的代码如下:

import torchtext
import torch

MAX_VOCAB_SIZE = 50000
BATCH_SIZE = 32


device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

TEXT = torchtext.data.Field(lower=True)

train, val, test = torchtext.datasets.LanguageModelingDataset.splits(path="data", train="text8.train.txt",
                                                                     validation="text8.dev.txt", test="text8.test.txt",
                                                                     text_field=TEXT)
# 构造词典
TEXT.build_vocab(train, max_size=MAX_VOCAB_SIZE)

print(TEXT.vocab.itos[:10])
print(TEXT.vocab.stoi["apple"])

# 构造dataLoader
train_iter, val_iter, test_iter = torchtext.data.BPTTIterator.splits((train, val, test), batch_size=BATCH_SIZE,
                                                                     device=device, bptt_len=50, repeat=False, shuffle=True)


VOCAB_SIZE = len(TEXT.vocab)
print(VOCAB_SIZE)

for batch in train_iter:
    data, target = batch.text, batch.target
    # data.shape:(seq_len, batch_size)
    # target.shape:(seq_len, batch_size)
    print(data.shape)
    print(target.shape)
    break

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!