Delete 'train.py'
This commit is contained in:
parent 6b714b7556
commit d713b4a83f

train.py: 124 deletions

@@ -1,124 +0,0 @@
from torch import nn
import torch
from torch.utils.data import IterableDataset
import itertools
import lzma
import regex as re
import pickle
import scripts  # local helper module providing get_words_from_line, vocab_size and learning_rate
def look_ahead_iterator(gen):
    # Slide over the token stream, yielding (prev, current, next) triples.
    prev = None
    current = None
    for next_item in gen:  # 'next' renamed to avoid shadowing the builtin
        if prev is not None and current is not None:
            yield (prev, current, next_item)
        prev = current
        current = next_item
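# A quick illustration (hypothetical, not part of the original file):
# list(look_ahead_iterator(iter([1, 2, 3, 4]))) == [(1, 2, 3), (2, 3, 4)]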
def get_word_lines_from_file(file_name):
    # Stream the xz-compressed corpus line by line, capped at 100,000 lines,
    # yielding each line as a list of tokens.
    counter = 0
    with lzma.open(file_name, 'r') as fh:
        for line in fh:
            counter += 1
            if counter == 100000:
                break
            line = line.decode("utf-8")
            yield scripts.get_words_from_line(line)
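# Hypothetical illustration: each yielded item is the token list for one line,
# e.g. get_word_lines_from_file('train/in.tsv.xz') might yield
# ['he', 'sat', 'down'], ['it', 'was', 'cold'], ...
# (the exact tokens depend on scripts.get_words_from_line).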
class Trigrams(IterableDataset):
    def load_vocab(self):
        # The vocabulary is built beforehand and pickled to 'vocab.pickle'.
        with open("vocab.pickle", 'rb') as handle:
            vocab = pickle.load(handle)
        return vocab

    def __init__(self, text_file, vocabulary_size):
        self.vocab = self.load_vocab()
        self.vocab.set_default_index(self.vocab['<unk>'])
        self.vocabulary_size = vocabulary_size
        self.text_file = text_file

    def __iter__(self):
        # Flatten the corpus into one stream of token ids and group it
        # into (prev, current, next) triples.
        return look_ahead_iterator(
            (self.vocab[t] for t in itertools.chain.from_iterable(
                get_word_lines_from_file(self.text_file))))


vocab_size = scripts.vocab_size

train_dataset = Trigrams('train/in.tsv.xz', vocab_size)
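# Each dataset element is a (prev_id, current_id, next_id) triple of
# vocabulary indices, e.g. (hypothetical values):
# next(iter(train_dataset))  ->  (431, 27, 1542)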
# === training ===

from torch.utils.data import DataLoader

embed_size = 100
class SimpleTrigramNeuralLanguageModel(nn.Module):
    def __init__(self, vocabulary_size, embedding_size):
        super(SimpleTrigramNeuralLanguageModel, self).__init__()
        self.embedings = nn.Embedding(vocabulary_size, embedding_size)
        # Hidden layer over the two concatenated context embeddings,
        # followed by a projection to the vocabulary.
        self.linear_first_layer = nn.Linear(embedding_size * 2, embedding_size * 2)
        self.relu = nn.ReLU()
        self.linear = nn.Linear(embedding_size * 2, vocabulary_size)
        self.softmax = nn.Softmax(dim=1)

        # Earlier, simpler variant kept for reference:
        # self.model = nn.Sequential(
        #     nn.Embedding(vocabulary_size, embedding_size),
        #     nn.Linear(embedding_size, vocabulary_size),
        #     nn.Softmax()
        # )
    def forward(self, x):
        # x is a pair [prev_batch, next_batch] of token-id tensors;
        # the model predicts the word standing between them.
        emb_1 = self.embedings(x[0])
        emb_2 = self.embedings(x[1])

        first_layer = self.linear_first_layer(torch.cat((emb_1, emb_2), dim=1))
        after_relu = self.relu(first_layer)
        logits = self.linear(after_relu)

        y = self.softmax(logits)
        return y
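# Shape sketch (derived from the layers above): for a batch of B contexts,
# model([prev_ids, next_ids]) with prev_ids and next_ids of shape (B,)
# returns a (B, vocabulary_size) tensor of probabilities summing to 1 per row.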
vocab = train_dataset.vocab

device = 'cuda'
model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size).to(device)
data = DataLoader(train_dataset, batch_size=12800)
optimizer = torch.optim.Adam(model.parameters(), lr=scripts.learning_rate)
criterion = torch.nn.NLLLoss()
model.train()
step = 0
epochs = 4
for i in range(epochs):
    for x, y, z in data:
        # (x, y, z) = (previous word, target word, next word)
        x = x.to(device)
        y = y.to(device)
        z = z.to(device)
        optimizer.zero_grad()
        ypredicted = model([x, z])
        # NLLLoss expects log-probabilities; torch.log over the softmax output
        # works, though an nn.LogSoftmax head would be numerically safer.
        loss = criterion(torch.log(ypredicted), y)
        if step % 2000 == 0:
            print(step, loss)
            # torch.save(model.state_dict(), f'model1_{step}.bin')
        step += 1
        loss.backward()
        optimizer.step()
    torch.save(model.state_dict(), f'batch_model_epoch_{i}.bin')
    print(step, loss, f'model_epoch_{i}.bin')

torch.save(model.state_dict(), 'model_tri1.bin')
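# A minimal inference sketch (an assumption, not in the original file): predict
# the middle word of a gap from its neighbours' ids prev_id and next_id:
#
#   probs = model([torch.tensor([prev_id]).to(device),
#                  torch.tensor([next_id]).to(device)])
#   top = torch.topk(probs, 10, dim=1)   # ten most probable middle words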