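A language model predicts the next word from the words seen so far. So when the dataset is built, if an input sequence is abcdefg, its target is bcdefgh: the same sequence shifted one position ahead.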
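A minimal sketch of this one-step shift, using the letters from the example above as tokens. The helper name `make_lm_pair` is hypothetical and only illustrates the idea:

```python
def make_lm_pair(tokens):
    """Split a token sequence into (inputs, targets) for next-token prediction.

    The target at position i is the token at position i + 1 of the stream.
    """
    inputs = tokens[:-1]   # every token except the last
    targets = tokens[1:]   # every token except the first
    return inputs, targets


inputs, targets = make_lm_pair(list("abcdefgh"))
print("".join(inputs))   # abcdefg
print("".join(targets))  # bcdefgh
```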
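The data is text8, an English dataset with no punctuation and no line breaks. The dataset is preprocessed with torchtext; the preprocessing code is as follows: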
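# Uses the legacy torchtext API (Field, BPTTIterator), available as written in torchtext < 0.9.
# In torchtext 0.9 to 0.11 these classes live under torchtext.legacy; they were removed in 0.12.
import torchtext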
import torch
MAX_VOCAB_SIZE = 50000
BATCH_SIZE = 32
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
TEXT = torchtext.data.Field(lower=True)
# Load the train/validation/test splits from the data/ directory
train, val, test = torchtext.datasets.LanguageModelingDataset.splits(
    path="data", train="text8.train.txt",
    validation="text8.dev.txt", test="text8.test.txt",
    text_field=TEXT)
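# Build the vocabulary from the training set, keeping the MAX_VOCAB_SIZE most frequent words
TEXT.build_vocab(train, max_size=MAX_VOCAB_SIZE)
print(TEXT.vocab.itos[:10])      # first 10 entries of the index-to-string list
print(TEXT.vocab.stoi["apple"])  # string-to-index lookup for "apple"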
# Build the iterators: BPTTIterator cuts the text into windows of bptt_len tokens
train_iter, val_iter, test_iter = torchtext.data.BPTTIterator.splits(
    (train, val, test), batch_size=BATCH_SIZE,
    device=device, bptt_len=50, repeat=False, shuffle=True)
VOCAB_SIZE = len(TEXT.vocab)
print(VOCAB_SIZE)
for batch in train_iter:
    data, target = batch.text, batch.target
    # data.shape: (seq_len, batch_size)
    # target.shape: (seq_len, batch_size), i.e. data shifted one token ahead
    print(data.shape)
    print(target.shape)
    break
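To make the `(seq_len, batch_size)` shapes concrete, here is a plain-Python sketch of roughly what this pipeline does: build a frequency-capped vocabulary, map words to indices (unknown words go to `<unk>`), lay the token stream out as `batch_size` parallel columns, and cut it into windows of `bptt_len` rows, with the target window shifted one row ahead. This is an illustration under simplifying assumptions, not torchtext's actual implementation: the helper names `build_vocab`, `numericalize`, and `bptt_batches`, the `<unk>`/`<pad>` specials, and the alphabetical tie-breaking are choices made for this sketch, and leftover tokens that do not fill a full column are simply dropped.

```python
from collections import Counter


def build_vocab(tokens, max_size):
    """Hypothetical helper: special tokens first, then the max_size most frequent words."""
    specials = ["<unk>", "<pad>"]
    counts = Counter(tokens)
    # Sort by descending frequency, breaking ties alphabetically for determinism.
    words = sorted(counts, key=lambda w: (-counts[w], w))[:max_size]
    itos = specials + words
    stoi = {w: i for i, w in enumerate(itos)}
    return itos, stoi


def numericalize(tokens, stoi):
    unk = stoi["<unk>"]
    return [stoi.get(t, unk) for t in tokens]  # out-of-vocabulary words map to <unk>


def bptt_batches(ids, batch_size, bptt_len):
    """Yield (data, target) windows of shape (seq_len, batch_size); target is shifted by one."""
    n_rows = len(ids) // batch_size  # leftover tokens are dropped (simplification)
    # Column b holds the contiguous chunk ids[b * n_rows:(b + 1) * n_rows].
    grid = [[ids[b * n_rows + r] for b in range(batch_size)] for r in range(n_rows)]
    for start in range(0, n_rows - 1, bptt_len):
        seq_len = min(bptt_len, n_rows - 1 - start)
        data = grid[start:start + seq_len]
        target = grid[start + 1:start + 1 + seq_len]
        yield data, target


text = "the cat sat on the mat the dog sat on the log".split()
itos, stoi = build_vocab(text, max_size=5)
ids = numericalize(text, stoi)
for data, target in bptt_batches(ids, batch_size=2, bptt_len=3):
    print(len(data), len(data[0]))  # prints "3 2", then "2 2"
```

Laying the stream out column by column keeps each column a contiguous piece of text, so consecutive windows continue each column where the previous window stopped; the last window is shorter because the final row has no next token to serve as a target.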
Source: oschina
Link: https://my.oschina.net/u/4228078/blog/4462413