"""Train a punctuation-restoration model on the PolEval forced-alignment data.

Each training example is built from a window of lines of one data file:
``context_size`` consecutive words, then (skipping one line) the following
word.  The label is the punctuation class found in the skipped line.
"""
import os

import numpy as np
import spacy
import torch
import tqdm

from util import Model


def clean_string(str):
    """Return *str* with all newline characters removed.

    (The parameter name shadows the builtin ``str``; it is kept for
    backward compatibility with keyword callers.)
    """
    str = str.replace('\n', '')
    return str


def extract_word(line):
    """Return the second space-separated field of *line* (the word column)."""
    return line.split(" ")[1]


def line2word(line):
    """Extract the word field of a data line and strip newlines from it."""
    word = extract_word(line)
    word = clean_string(word)
    return word


def find_interpunction(line, classes):
    """Return the first punctuation class from *classes* contained in *line*.

    Longer classes are tested first so that ``'...'`` is reported as
    ``'...'`` rather than being shadowed by ``'.'``; classes of equal length
    keep their order in *classes*.  Returns ``['']`` (a one-element list,
    as before) when no class occurs in *line*.
    """
    for mark in sorted(classes, key=len, reverse=True):
        if mark in line:
            return mark
    return ['']


def words_to_vecs(list_of_words):
    """Return the spaCy ``Doc.vector`` of every word, using the global ``nlp``.

    NOTE(review): an earlier inline note said each vector is a (96,) numpy
    ndarray for ``pl_core_news_sm`` — confirm if the spaCy model changes.
    """
    return [nlp(x).vector for x in list_of_words]


def softXEnt(input, target):
    """Soft-target cross entropy averaged over the first dimension of *input*.

    Applies log-softmax over dim 1 of *input*, multiplies by *target*
    (broadcast), sums everything and divides by ``input.shape[0]``.
    """
    m = torch.nn.LogSoftmax(dim=1)
    logprobs = m(input)
    return -(target * logprobs).sum() / input.shape[0]


def compute_class_vector(mark, classes):
    """Return a one-hot float tensor marking *mark*'s position in *classes*.

    *mark* may be a string (the class itself, e.g. ``','`` or ``'...'``) or
    a sequence whose first element is the class, such as the ``['']``
    returned by ``find_interpunction`` when nothing was found.  An unknown
    or empty mark yields an all-zero vector.
    """
    if isinstance(mark, str):
        # Compare the whole string so '...' is not truncated to '.'.
        label = mark
    else:
        label = mark[0] if len(mark) > 0 else ''
    result = np.zeros(len(classes))
    for idx, cls in enumerate(classes):
        if cls == label:
            result[idx] = 1  # was `==` (a no-op comparison) in the original
    # Float, not long: the target is a probability distribution for softXEnt.
    return torch.tensor(result, dtype=torch.float)


data_dir = "./fa/poleval_final_dataset/train"
data_nopunc_dir = "./fa/poleval_final_dataset1/train"
# sorted() makes the file order (and thus training) reproducible across runs.
data_paths = [os.path.join(data_dir, x) for x in sorted(os.listdir(data_dir))]
classes = [',', '.', '?', '!', '-', ':', '...']

nlp = spacy.load("pl_core_news_sm")
context_size = 5
model = Model()
epochs = 5
output_prefix = "model"
hidden_state = torch.randn((2, 1, 300), requires_grad=True)
cell_state = torch.randn((2, 1, 300), requires_grad=True)

model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.02)
loss_function = softXEnt

for epoch in range(epochs):
    for path in tqdm.tqdm(data_paths):
        # Explicit UTF-8: the data is Polish text and the platform default
        # encoding is not guaranteed to be UTF-8.
        with open(path, "r", encoding="utf-8") as file:
            lines = file.readlines()[:-1]  # the last line is dropped, as before

        for i in range(0, len(lines) - context_size - 1):
            model.zero_grad()

            # Input: context_size words, then the word after the skipped line.
            words = [line2word(y) for y in lines[i: i + context_size]]
            words.append(line2word(lines[i + context_size + 1]))
            x = torch.tensor(np.stack(words_to_vecs(words)), dtype=torch.float)

            # Label: punctuation found in the word field of the skipped line.
            # NOTE(review): assumes the skipped line i + context_size is the
            # one carrying the punctuation to predict — confirm against the
            # data format.  The original searched the already-vectorised
            # input window, which could never match a string.
            mark = find_interpunction(line2word(lines[i + context_size]), classes)
            target = compute_class_vector(mark, classes)

            # Call the module (not .forward) so nn.Module hooks run.
            output, (hidden_state, cell_state) = model(x, hidden_state, cell_state)
            output = output.squeeze(1)
            loss = loss_function(output, target)
            loss.backward()
            optimizer.step()

            # Truncate backprop through time: keep the state values but cut
            # the graph so the next step does not backprop into this one.
            hidden_state = hidden_state.detach()
            cell_state = cell_state.detach()

    print("Epoch: {}".format(epoch))
    torch.save(
        model.state_dict(),
        os.path.join("./", f"{output_prefix}-{epoch}.pt"),
    )