# Bigram neural language model — training script exported from a Jupyter notebook.
import torch
from torch import nn
# Model hyperparameters: vocabulary capped at 20k types, 100-dim embeddings.
vocab_size = 20000
embed_size = 100
class SimpleTrigramNeuralLanguageModel(nn.Module):
    """Single-token embedding language model: embedding -> linear -> softmax.

    NOTE(review): despite "Trigram" in the name, forward() consumes one token
    index per example and predicts a distribution over the next token
    (bigram-style) — confirm the intended model order.
    """

    def __init__(self, vocabulary_size, embedding_size):
        super().__init__()
        # Lookup table mapping token ids to dense vectors.
        self.embedding = nn.Embedding(vocabulary_size, embedding_size)
        # Projects embeddings back onto the vocabulary.
        self.linear = nn.Linear(embedding_size, vocabulary_size)

    def forward(self, x):
        """Return a (batch, vocab) matrix of next-token probabilities."""
        embedded = self.embedding(x)
        logits = self.linear(embedded)
        # Probabilities, not log-probabilities: the training loop applies
        # torch.log before NLLLoss. log_softmax would be numerically safer
        # but would change this module's output contract for callers.
        return torch.softmax(logits, dim=1)
import regex as re
from itertools import islice, chain
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import IterableDataset
def get_words_from_line(line):
    """Tokenize one text line, wrapped in sentence-boundary markers.

    Yields '<s>', then each lowercased token (runs of letters/digits/'*' or
    runs of punctuation, per the `regex` module's Unicode classes), then '</s>'.
    """
    line = line.rstrip()
    yield '<s>'
    # \p{L} / \p{P} require the third-party `regex` module (imported as re).
    for match in re.finditer(r'[\p{L}0-9\*]+|\p{P}+', line):
        yield match.group(0).lower()
    yield '</s>'
def get_word_lines_from_file(file_name):
    """Yield one token generator (see get_words_from_line) per line of the file."""
    # Explicit encoding so behaviour does not depend on the platform default.
    with open(file_name, 'r', encoding='utf-8') as fh:
        for line in fh:
            yield get_words_from_line(line)
def look_ahead_iterator(gen):
    """Yield consecutive (previous, current) pairs from an iterable.

    Uses the iterator protocol directly instead of a ``None`` sentinel, so a
    ``None`` item in the input no longer silently drops its pair.
    """
    iterator = iter(gen)
    try:
        prev = next(iterator)
    except StopIteration:
        return  # empty input -> no pairs
    for item in iterator:
        yield (prev, item)
        prev = item
class Bigrams(IterableDataset):
    """Streams (previous_token_id, current_token_id) pairs from a text file."""

    def __init__(self, text_file, vocabulary_size):
        # Build the vocabulary once, capped at vocabulary_size entries;
        # out-of-vocabulary tokens map to the <unk> index.
        self.vocab = build_vocab_from_iterator(
            get_word_lines_from_file(text_file),
            max_tokens=vocabulary_size,
            specials=['<unk>'],
        )
        self.vocab.set_default_index(self.vocab['<unk>'])
        self.vocabulary_size = vocabulary_size
        self.text_file = text_file

    def __iter__(self):
        # Re-read the file on each epoch and emit adjacent token-id pairs.
        token_ids = (
            self.vocab[token]
            for token in chain.from_iterable(get_word_lines_from_file(self.text_file))
        )
        return look_ahead_iterator(token_ids)
from torch.utils.data import DataLoader
# Fall back to CPU when no GPU is available instead of crashing at .to(device).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_dataset = Bigrams('europarl.txt', vocab_size)
model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size).to(device)
data = DataLoader(train_dataset, batch_size=2000)
optimizer = torch.optim.Adam(model.parameters())
# NLLLoss expects log-probabilities; the training loop applies torch.log to
# the model's softmax output before calling it.
criterion = torch.nn.NLLLoss()
# NOTE: stale notebook traceback removed — an earlier revision's __init__
# called super(SimpleBigramNeuralLanguageModel, self).__init__() and raised
# NameError: 'SimpleBigramNeuralLanguageModel' is not defined; the class is
# now consistently named SimpleTrigramNeuralLanguageModel above.
import os

# Ensure the checkpoint directory exists before training starts.
os.makedirs('model', exist_ok=True)

step = 0  # was never initialized — caused a NameError on the first batch
for epoch in range(1):
    model.train()
    for x, y in data:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        outputs = model(x)
        # torch.log(softmax) is numerically unstable for tiny probabilities;
        # kept because the model returns probabilities — consider switching
        # the model to log_softmax and dropping the torch.log here.
        loss = criterion(torch.log(outputs), y)
        if step % 100 == 0:
            print(step, loss.item())
        step += 1
        loss.backward()
        optimizer.step()

torch.save(model.state_dict(), 'model/model1.bin')