# Notebook export: trains a simple bigram neural language model in PyTorch.
import torch
from torch import nn
# Hyperparameters: vocabulary truncated to the 20k most frequent tokens,
# each embedded into a 100-dimensional vector.
vocab_size = 20000
embed_size = 100
class SimpleTrigramNeuralLanguageModel(nn.Module):
    """Neural language model: embeds a token id and maps the embedding to a
    probability distribution over the next token.

    NOTE(review): despite the class name, this conditions on a single
    previous token, i.e. it is a bigram model.
    """

    def __init__(self, vocabulary_size, embedding_size):
        super().__init__()
        # Keep the attribute names 'embedding'/'linear' stable: they define
        # the state_dict keys used by torch.save / torch.load.
        self.embedding = nn.Embedding(vocabulary_size, embedding_size)
        self.linear = nn.Linear(embedding_size, vocabulary_size)

    def forward(self, x):
        # (batch,) token ids -> (batch, embed) -> (batch, vocab) logits,
        # normalized to probabilities along the vocabulary axis.
        logits = self.linear(self.embedding(x))
        return torch.softmax(logits, dim=1)
import regex as re
from itertools import islice, chain
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import IterableDataset
def get_words_from_line(line):
    """Tokenize one line, framed by sentence-boundary markers.

    Yields '<s>', then each lower-cased token (runs of letters/digits/'*'
    or runs of punctuation, per the regex), then '</s>'.
    """
    yield '<s>'
    for match in re.finditer(r'[\p{L}0-9\*]+|\p{P}+', line.rstrip()):
        yield match.group(0).lower()
    yield '</s>'
def get_word_lines_from_file(file_name):
    """Yield one token generator per line of *file_name*.

    Each yielded item is the lazy token stream produced by
    get_words_from_line for that line.
    """
    # Explicit UTF-8: the default encoding is platform-dependent (e.g.
    # cp1252 on Windows), which would mis-decode a UTF-8 corpus such as
    # Europarl.
    with open(file_name, 'r', encoding='utf-8') as fh:
        for line in fh:
            yield get_words_from_line(line)
def look_ahead_iterator(gen):
    """Yield consecutive (previous, current) pairs from an iterable.

    An input of fewer than two items yields nothing.  Uses a private
    sentinel rather than None, so None items in the stream are paired
    correctly instead of being silently skipped (the original used
    ``prev is not None`` and dropped pairs around None items).
    """
    missing = object()  # sentinel that cannot collide with stream values
    prev = missing
    for item in gen:
        if prev is not missing:
            yield (prev, item)
        prev = item
class Bigrams(IterableDataset):
    """Iterable dataset of (left, right) token-id pairs from a text corpus."""

    def __init__(self, text_file, vocabulary_size):
        # Build the vocabulary from the corpus; out-of-vocabulary tokens
        # resolve to '<unk>' via the default index.
        self.vocab = build_vocab_from_iterator(
            get_word_lines_from_file(text_file),
            max_tokens=vocabulary_size,
            specials=['<unk>'],
        )
        self.vocab.set_default_index(self.vocab['<unk>'])
        self.vocabulary_size = vocabulary_size
        self.text_file = text_file

    def __iter__(self):
        # Flatten the per-line token streams, map tokens to ids, and pair
        # each id with its successor.
        tokens = chain.from_iterable(get_word_lines_from_file(self.text_file))
        ids = (self.vocab[token] for token in tokens)
        return look_ahead_iterator(ids)
from torch.utils.data import DataLoader

device = 'cpu'
train_dataset = Bigrams('europarl.txt', vocab_size)
model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size).to(device)
data = DataLoader(train_dataset, batch_size=2000)
optimizer = torch.optim.Adam(model.parameters())
# NLLLoss expects log-probabilities; the model emits probabilities, hence
# the torch.log() below.  NOTE(review): log(softmax(...)) is numerically
# fragile -- having the model return log_softmax and feeding NLLLoss
# directly would be the robust formulation.
criterion = torch.nn.NLLLoss()

step = 0
for epoch in range(1):
    model.train()
    for x, y in data:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        outputs = model(x)
        # Clamp away exact zeros so log() cannot produce -inf / NaN losses
        # when the softmax underflows.
        loss = criterion(torch.log(outputs.clamp_min(1e-9)), y)
        if step % 100 == 0:
            # .item() prints a plain float instead of a tensor repr
            print(step, loss.item())
        step += 1
        loss.backward()
        optimizer.step()

import os
os.makedirs('model', exist_ok=True)  # torch.save does not create directories
torch.save(model.state_dict(), 'model/model1.bin')
# Captured notebook output (training log, then a manual interrupt):
#   step 0: loss 10.0424 | 100: 7.9016 | 200: 7.1964 | 300: 6.5661
#   | 400: 6.4146 | 500: 5.8718
# Training was stopped by hand with a KeyboardInterrupt raised inside
# model.forward (torch.nn.Linear.forward).