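"""Train an LSTM next-word model.

Streams 10-token windows from an xz-compressed training file, maps tokens to ids
with a pickled vocabulary (vocab.pickle), and trains an embedding + LSTM + linear
classifier, saving checkpoints as it goes. The commented-out block at the bottom
sketches inference on dev-0/in.tsv.xz.
"""
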
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, IterableDataset
import numpy as np
from collections import Counter
import string
import lzma
import copy
import itertools
import regex as re
import pickle
import os

import utils

# Train on the second GPU only.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = 'cuda'

# Load the pre-built vocabulary; out-of-vocabulary tokens map to <unk>.
with open("vocab.pickle", 'rb') as handle:
    vocab = pickle.load(handle)
vocab.set_default_index(vocab['<unk>'])


def get_word_lines_from_file(file_name):
    """Yield (input, target) pairs of seq_len token ids from an xz-compressed file."""
    counter = 0
    seq_len = 10
    with lzma.open(file_name, 'r') as fh:
        for line in fh:
            counter += 1
            # if counter == 100000:
            #     break
            line = line.decode("utf-8")

            line_splitted = utils.get_words_from_line(line)

            vocab_line = [vocab[t] for t in line_splitted]

            # Slide a window over the line: the target is the input shifted one token right.
            for i in range(len(vocab_line) - seq_len):
                yield torch.tensor(vocab_line[i:i + seq_len]), torch.tensor(vocab_line[i + 1:i + seq_len + 1])


class Grams_10(IterableDataset):
    """Iterable dataset that streams 10-token windows from a compressed text file."""

    def __init__(self, text_file, vocab):
        self.vocab = vocab
        self.vocab.set_default_index(self.vocab['<unk>'])
        self.text_file = text_file

    def __iter__(self):
        return get_word_lines_from_file(self.text_file)


vocab_size = utils.vocab_size

train_dataset = Grams_10('train/in.tsv.xz', vocab)

BATCH_SIZE = 1024


class Model(nn.Module):
    """Embedding -> bidirectional LSTM -> linear projection over the vocabulary."""

    def __init__(self, vocab_size):
        super(Model, self).__init__()
        self.lstm_size = 150
        self.embedding_dim = 200
        self.num_layers = 1

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            batch_first=True,
            bidirectional=True,
            # dropout=0.2,
        )
        # The LSTM is bidirectional, so its output is 2 * lstm_size wide.
        self.fc = nn.Linear(self.lstm_size * 2, vocab_size)

    def forward(self, x, prev_state=None):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, batch_size):
        # Hidden/cell state shape: (num_layers * num_directions, batch_size, hidden_size).
        return (torch.zeros(self.num_layers * 2, batch_size, self.lstm_size).to(device),
                torch.zeros(self.num_layers * 2, batch_size, self.lstm_size).to(device))
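

# Optional sanity check (a minimal sketch, not part of the training run): with
# batch_first=True, a (batch, seq_len) tensor of token ids should yield logits
# of shape (batch, seq_len, vocab_size).
#   _model = Model(vocab_size)
#   _x = torch.randint(0, vocab_size, (2, 10))
#   _logits, _ = _model(_x)
#   assert _logits.shape == (2, 10, vocab_size)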


def train(dataloader, model, max_epochs):
    model.train()

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)

    for epoch in range(max_epochs):
        step = 0
        for batch_i, (x, y) in enumerate(dataloader):
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()

            y_pred, (state_h, state_c) = model(x)
            # CrossEntropyLoss expects (batch, vocab_size, seq_len), so swap the last two dims.
            loss = criterion(y_pred.transpose(1, 2), y)

            loss.backward()
            optimizer.step()

            step += 1
            if step % 500 == 0:
                print({'epoch': epoch, 'step': step, 'loss': loss.item()})
            if step % 5000 == 0:
                torch.save(model.state_dict(), f'lstm_step_{step}.bin')
        torch.save(model.state_dict(), f'lstm_epoch_{epoch}.bin')


print('Hello, starting training')
model = Model(vocab_size=vocab_size).to(device)

dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
train(dataloader, model, 1)
torch.save(model.state_dict(), 'lstm.bin')


# def predict(model, text_splitted):
#     model.eval()
#     words = text_splitted
#
#     x = torch.tensor([[vocab[w] for w in words]]).to(device)
#     state_h, state_c = model.init_state(x.size()[0])
#
#     y_pred, (state_h, state_c) = model(x, (state_h, state_c))
#
#     last_word_logits = y_pred[0][-1]
#     p = torch.nn.functional.softmax(last_word_logits, dim=0)
#
#     top = torch.topk(p, 64)
#     top_indices = top.indices.tolist()
#     top_probs = top.values.tolist()
#     top_words = vocab.lookup_tokens(top_indices)
#     return top_words, top_probs


# print('Hello, starting prediction')
# inference_result = []
# with lzma.open('dev-0/in.tsv.xz', 'r') as file:
#     for line in file:
#         line = line.decode("utf-8")
#         line = line.rstrip()
#         line = line.translate(str.maketrans('', '', string.punctuation))
#         line_splitted_by_tab = line.split('\t')
#         left_context = line_splitted_by_tab[-2]
#
#         left_context_splitted = list(utils.get_words_from_line(left_context))
#
#         top_words, top_probs = predict(model, left_context_splitted)
#
#         string_to_print = ''
#
#         sum_probs = 0
#         for w, p in zip(top_words, top_probs):
#             if '<unk>' in w:
#                 continue
#             string_to_print += f"{w}:{p} "
#             sum_probs += p
#
#         if string_to_print == '':
#             inference_result.append("the:0.2 a:0.3 :0.5")
#             continue
#         unknown_prob = 1 - sum_probs
#         string_to_print += f":{unknown_prob}"
#
#         inference_result.append(string_to_print)
#
# with open('dev-0/out.tsv', 'w') as f:
#     for line in inference_result:
#         f.write(line + '\n')

print('All done')