import os
import pickle
from typing import Counter

import numpy as np
import spacy
import torch
import tqdm

from util import Model


def clean_string(text):
    """Strip newline characters from a dataset line."""
    return text.replace('\n', '')


def extract_word(line):
    """Return the second whitespace-delimited token of *line*.

    NOTE(review): assumes each dataset line is at least two space-separated
    fields with the word in position 1 — confirm against the data format.
    """
    return line.split(" ")[1]


def line2word(line):
    """Extract the word token from a dataset line and clean it."""
    return clean_string(extract_word(line))


def find_interpunction(line, classes):
    """Return the first entry of *classes* that occurs in *line*.

    Because *classes* ends with '' (which is contained in every string),
    this returns '' when no real punctuation mark is present.
    BUG FIX: the original returned the string ``result[0]`` in one branch
    but the *list* ``['']`` in the (unreachable) fallback branch; callers
    compare the result against strings, so a string is always returned now.
    """
    for mark in classes:
        if mark in line:
            return mark
    return ''


def words_to_vecs(list_of_words):
    """Map each word to its spaCy embedding vector."""
    return [nlp(word).vector for word in list_of_words]


def softXEnt(input, target):
    """Soft-target cross entropy, averaged over the batch dimension."""
    logprobs = torch.nn.LogSoftmax(dim=1)(input)
    return -(target * logprobs).sum() / input.shape[0]


def compute_class_vector(mark, classes):
    """One-hot encode the punctuation mark held in ``mark[0]``.

    *mark* is a one-element sequence containing the full mark string
    (e.g. ``['...']``); returns a float tensor of length ``len(classes)``.
    """
    result = np.zeros(len(classes))
    for i, cls in enumerate(classes):
        if cls == mark[0]:
            result[i] = 1.0
    return torch.tensor(result, dtype=torch.float)


def prepare_input(index, data, context_size):
    """Build one (input tensor, one-hot target) sample starting at *index*.

    BUG FIX: the original indexed the builtin name ``list`` (a global leaked
    from the training loop) instead of the ``data`` parameter, and passed a
    bare string to compute_class_vector so only its first *character* was
    matched (breaking the multi-character mark '...').
    """
    context = [line2word(line) for line in data[index: index + context_size]]
    # NOTE(review): the follow-up word deliberately skips one line
    # (index + context_size + 1, not + context_size) — mirrors the training
    # loop; confirm this offset against the dataset layout.
    follow_up = [line2word(data[index + context_size + 1])]
    mark = find_interpunction(context[-1], classes)
    vectors = words_to_vecs(context + follow_up)
    x = torch.tensor(vectors, dtype=torch.float)
    return x, compute_class_vector([mark], classes)


def prepare_batch(index, data, context_size, batch_size):
    """Collect *batch_size* consecutive samples starting at *index*."""
    return [prepare_input(i, data, context_size)
            for i in range(index, index + batch_size)]


data_dir = "./fa/poleval_final_dataset/train"
data_nopunc_dir = "./fa/poleval_final_dataset1/train"

mode = "train"
# mode = "evaluate"
# mode = "generate"

data_paths = [data_dir + "/" + p for p in os.listdir(data_dir)]
data_no_punc_paths = [data_nopunc_dir + "/" + p
                      for p in os.listdir(data_nopunc_dir)]

# '' encodes "no punctuation".  NOTE(review): '.' precedes '...', so
# find_interpunction can never select '...' with this ordering — confirm
# whether that is intended.
classes = [',', '.', '?', '!', '-', ':', '...', '']

nlp = spacy.load("pl_core_news_sm")
context_size = 5
model = Model()
epochs = 5
output_prefix = "model"
train_loss_acc = 30  # gradient-accumulation window (in samples)

device = torch.device("cuda")
model = model.cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 15, 0.0001)
loss_function = torch.nn.MSELoss()
model.train()

# TODO:
# 1) add batch processing (see prepare_batch)
# 2) change the loss function

counter = 0
# model.load_state_dict(torch.load("model-4.pt"))  # uncomment to resume

if mode == "train":
    for epoch in range(epochs):
        for path in tqdm.tqdm(data_paths[:20]):
            with open(path, mode="r", encoding="utf-8") as file:
                lines = file.readlines()[:100]
            for i in range(0, len(lines) - context_size - 2):
                context = [line2word(y) for y in lines[i: i + context_size]]
                follow_up = [line2word(lines[i + context_size + 1])]
                mark = find_interpunction(context[-1], classes)
                if mark == '':
                    continue  # skip samples without punctuation
                x = torch.tensor(words_to_vecs(context + follow_up),
                                 dtype=torch.float).to(device)
                output, (_, _) = model(x)
                output = output.squeeze(1)
                class_vector = compute_class_vector([mark], classes).to(device)
                loss = loss_function(torch.mean(output, 0), class_vector)
                if counter % 10 == 0:
                    print(torch.mean(output, 0))
                    print(loss)
                    print(class_vector)
                loss.backward()
                # BUG FIX: the original called model.zero_grad() at the top
                # of every iteration, wiping the gradients it was trying to
                # accumulate across `train_loss_acc` steps; gradients are now
                # zeroed only after an optimizer step.  The scheduler is also
                # stepped AFTER the optimizer, per PyTorch's documented order.
                if counter % train_loss_acc == 0:
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                counter += 1
        print("Epoch: {}".format(epoch))
        torch.save(
            model.state_dict(),
            os.path.join("./", f"{output_prefix}-{epoch}.pt"),
        )
elif mode == "evaluate":
    correct = 0
    incorrect = 0
    threshold = 0.3
    model.load_state_dict(torch.load("model-4.pt"))
    model.eval()
    for pathA, pathB in zip(data_no_punc_paths, data_paths):
        with open(pathA, mode="r", encoding='utf-8') as file, \
                open(pathB, mode="r", encoding='utf-8') as file2:
            listA = file.readlines()[:-1]
            listB = file2.readlines()[:-1]
        # NOTE(review): listB (the punctuated files) is read but never used;
        # the ground-truth mark is taken from the no-punctuation file, which
        # looks wrong — confirm whether mark_y should come from listB.
        for i in range(0, len(listA) - context_size - 1):
            words = [line2word(y) for y in listA[i: i + context_size]]
            words += [line2word(listA[i + context_size + 1])]
            # BUG FIX: the original looked up the punctuation mark AFTER
            # converting words to vectors, so it searched a numpy vector
            # instead of the word string.
            mark_y = find_interpunction(words[-1], classes)
            x = torch.tensor(words_to_vecs(words),
                             dtype=torch.float).to(device)
            with torch.no_grad():
                output, (hidden_state, cell_state) = model(x)
            output = np.squeeze(np.mean(output.cpu().numpy(), axis=0))
            print(output)
            if mark_y:
                if classes[np.argmax(output)] == mark_y:
                    correct += 1
                else:
                    incorrect += 1
            else:
                incorrect += 1
    accuracy = correct / (correct + incorrect)
    print(accuracy)