# challenging-america-word-ga.../lstm.py
import lzma
import os
import pickle
import string  # used by the commented-out inference block below

import torch
from torch import nn, optim
from torch.utils.data import DataLoader, IterableDataset

import utils

os.environ["CUDA_VISIBLE_DEVICES"] = "1"  # pin the script to GPU 1
device = 'cuda'  # the script assumes a CUDA-capable GPU is available
# vocab is assumed to be a torchtext Vocab (it exposes set_default_index and
# lookup_tokens); unknown tokens fall back to the index of '<unk>'.
with open("vocab.pickle", 'rb') as handle:
    vocab = pickle.load(handle)
vocab.set_default_index(vocab['<unk>'])


def get_word_lines_from_file(file_name):
    # Stream (input, target) windows of seq_len tokens from the compressed
    # corpus: the target is the input shifted one token to the right, i.e.
    # the standard next-word-prediction setup.
    seq_len = 10
    with lzma.open(file_name, 'r') as fh:
        for line in fh:
            line = line.decode("utf-8")
            line_splitted = utils.get_words_from_line(line)
            vocab_line = [vocab[t] for t in line_splitted]
            for i in range(len(vocab_line) - seq_len):
                x = torch.tensor(vocab_line[i:i + seq_len])
                y = torch.tensor(vocab_line[i + 1:i + seq_len + 1])
                yield x, y
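
# Illustration (not in the original): with seq_len = 10 and a tokenized line
# [t0, t1, ..., t11], the first pair yielded is x = [t0..t9], y = [t1..t10],
# so position j of the model output is trained to predict token j + 1.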


class Grams_10(IterableDataset):
    # Thin IterableDataset wrapper so a DataLoader can stream token windows
    # straight from the compressed training file.
    def __init__(self, text_file, vocab):
        self.vocab = vocab
        self.vocab.set_default_index(self.vocab['<unk>'])
        self.text_file = text_file

    def __iter__(self):
        return get_word_lines_from_file(self.text_file)
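
# Note (assumption, not stated in the original): because Grams_10 is an
# IterableDataset, the DataLoader below cannot shuffle it, and num_workers > 0
# would make every worker stream the whole file, duplicating examples unless
# the iterator is sharded per worker.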
vocab_size = utils.vocab_size
train_dataset = Grams_10('train/in.tsv.xz', vocab)
BATCH_SIZE = 1024


class Model(nn.Module):
    def __init__(self, vocab_size):
        super(Model, self).__init__()
        self.lstm_size = 150
        self.embedding_dim = 200
        self.num_layers = 1
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=self.embedding_dim,
        )
        # Note: a bidirectional LSTM lets the backward direction see tokens to
        # the right of each position, so this is not a strictly causal
        # language model.
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            batch_first=True,
            bidirectional=True,
            # dropout=0.2,
        )
        # Both directions are concatenated, hence lstm_size * 2 input features.
        self.fc = nn.Linear(self.lstm_size * 2, vocab_size)

    def forward(self, x, prev_state=None):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, batch_size):
        # Hidden/cell state shape is (num_layers * num_directions, batch,
        # hidden), so the argument here is really the batch size, not the
        # sequence length as the original parameter name suggested.
        return (torch.zeros(self.num_layers * 2, batch_size, self.lstm_size).to(device),
                torch.zeros(self.num_layers * 2, batch_size, self.lstm_size).to(device))
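
# Hedged sanity check (not part of the original script); uncomment to verify
# the tensor shapes the model produces on a random batch:
# _m = Model(vocab_size=vocab_size).to(device)
# _x = torch.randint(0, vocab_size, (2, 10)).to(device)  # (batch, seq_len)
# _logits, _ = _m(_x)
# assert _logits.shape == (2, 10, vocab_size)  # (batch, seq_len, vocab)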


def train(dataloader, model, max_epochs):
    model.train()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.01)
    for epoch in range(max_epochs):
        step = 0
        for x, y in dataloader:
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            y_pred, (state_h, state_c) = model(x)
            # CrossEntropyLoss expects (batch, classes, seq), so move the
            # vocabulary dimension from last place to position 1.
            loss = criterion(y_pred.transpose(1, 2), y)
            loss.backward()
            optimizer.step()
            step += 1
            if step % 500 == 0:
                print({'epoch': epoch, 'step': step, 'loss': loss.item()})
            if step % 5000 == 0:
                torch.save(model.state_dict(), f'lstm_step_{step}.bin')
        torch.save(model.state_dict(), f'lstm_epoch_{epoch}.bin')
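
# Optional tweak (not in the original): LSTM training is often stabilized by
# clipping gradients between loss.backward() and optimizer.step(), e.g.:
#     torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)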


print('Hello, starting training')
model = Model(vocab_size=vocab_size).to(device)
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE)
train(train_dataloader, model, 1)
torch.save(model.state_dict(), 'lstm.bin')
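
# To reload the trained weights later (standard PyTorch API, not exercised in
# the original script):
# model = Model(vocab_size=vocab_size).to(device)
# model.load_state_dict(torch.load('lstm.bin', map_location=device))
# model.eval()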

# def predict(model, text_splitted):
#     model.eval()
#     words = text_splitted
#     x = torch.tensor([[vocab[w] for w in words]]).to(device)
#     state_h, state_c = model.init_state(x.size()[0])
#     y_pred, (state_h, state_c) = model(x, (state_h, state_c))
#     last_word_logits = y_pred[0][-1]
#     p = torch.nn.functional.softmax(last_word_logits, dim=0)
#     top = torch.topk(p, 64)
#     top_indices = top.indices.tolist()
#     top_probs = top.values.tolist()
#     top_words = vocab.lookup_tokens(top_indices)
#     return top_words, top_probs

# print('Hello, starting prediction')
# inference_result = []
# with lzma.open('dev-0/in.tsv.xz', 'r') as file:
#     for line in file:
#         line = line.decode("utf-8")
#         line = line.rstrip()
#         line = line.translate(str.maketrans('', '', string.punctuation))
#         line_splitted_by_tab = line.split('\t')
#         left_context = line_splitted_by_tab[-2]
#         left_context_splitted = list(utils.get_words_from_line(left_context))
#         top_words, top_probs = predict(model, left_context_splitted)
#         string_to_print = ''
#         sum_probs = 0
#         for w, p in zip(top_words, top_probs):
#             if w == '<unk>':
#                 continue
#             string_to_print += f"{w}:{p} "
#             sum_probs += p
#         if string_to_print == '':
#             inference_result.append("the:0.2 a:0.3 :0.5")
#             continue
#         unknown_prob = 1 - sum_probs
#         string_to_print += f":{unknown_prob}"
#         inference_result.append(string_to_print)
# with open('dev-0/out.tsv', 'w') as f:
#     for line in inference_result:
#         f.write(line + '\n')
print('All done')