From 808efa0aade4127b05078734ea4b36546cfc617e Mon Sep 17 00:00:00 2001
From: wangobango
Date: Wed, 16 Jun 2021 22:23:55 +0200
Subject: [PATCH] 123

---
 .gitignore |  4 ++-
 main.py    | 80 ++++++++++++++++++++++++++++++------------------
 util.py    |  5 ++--
 3 files changed, 50 insertions(+), 39 deletions(-)

diff --git a/.gitignore b/.gitignore
index 4318506..90284b4 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,4 +7,6 @@
 .DS_Store
 .token
 .vscode
-fa/*
\ No newline at end of file
+fa/*
+*.pt
+*.pickle
\ No newline at end of file
diff --git a/main.py b/main.py
index 1d30bed..0949e3e 100644
--- a/main.py
+++ b/main.py
@@ -4,6 +4,7 @@ import spacy
 import torch
 import numpy as np
 import tqdm
+import pickle
 
 def clean_string(str):
     str = str.replace('\n', '')
@@ -36,7 +37,7 @@ def compute_class_vector(mark, classes):
     result = np.zeros(len(classes))
     for x in range(len(classes)):
         if classes[x] == mark[0]:
-            result[x] == 1
+            result[x] = 1.0
     return torch.tensor(result, dtype=torch.long)
 
 def prepare_input(index, data, context_size):
@@ -65,7 +66,10 @@ mode = "train"
 data_paths = os.listdir(data_dir)
 data_paths = [data_dir + "/" + x for x in data_paths]
 
-classes = [',', '.', '?', '!', '-', ':', '...']
+data_no_punc_paths = os.listdir(data_nopunc_dir)
+data_no_punc_paths = [data_nopunc_dir + "/" + x for x in data_no_punc_paths]
+
+classes = [',', '.', '?', '!', '-', ':', '...', '']
 
 nlp = spacy.load("pl_core_news_sm")
 context_size = 5
@@ -93,8 +97,8 @@ if mode == "train":
     for epoch in range(epochs):
         for path in tqdm.tqdm(data_paths):
             with open(path, mode="r", encoding="utf-8") as file:
-                list = file.readlines()[:-1]
-                for i in range(0, len(list) - context_size - 1):
+                list = file.readlines()
+                for i in range(0, 10):
                     model.zero_grad()
                     x = list[i: i + context_size]
                     x = [line2word(y) for y in x]
@@ -109,7 +113,7 @@
                     x = x.to(device)
                     output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
                     output = output.squeeze(1)
-                    loss = loss_function(output, compute_class_vector(mark, classes).to(device))
+                    loss = loss_function(output, compute_class_vector([mark], classes).to(device))
                     loss.backward()
                     optimizer.step()
                     hidden_state = hidden_state.detach()
@@ -120,43 +124,47 @@ if mode == "train":
             model.state_dict(),
             os.path.join("./", f"{output_prefix}-{epoch}.pt"),
         )
+    with open("hidden_state.pickle", "wb") as hs:
+        pickle.dump(hidden_state, hs)
+    with open("cell_state.pickle", "wb") as cs:
+        pickle.dump(cell_state, cs)
 
 elif mode == "evaluate":
     correct = 0
     incorrect = 0
     threshold = 0.3
-    for pathA, pathB in zip(data_nopunc_dir, data_dir):
-        listA = []
-        listB = []
+    model.load_state_dict(torch.load("model-0.pt"))
+    model.eval()
+    for pathA, pathB in zip(data_no_punc_paths, data_paths):
         with open(pathA, mode="r", encoding='utf-8') as file:
-            listA = file.readlines()[:-1]
-        with open(pathB, mode="r", encoding='utf-8') as file:
-            listb = file.readlines()[:-1]
-        for i in range(0, len(listA) - context_size - 1):
-            model.zero_grad()
-            x = listA[i: i + context_size]
-            x = [line2word(y) for y in x]
-            x_1 = [line2word(listA[i + context_size + 1])]
+            with open(pathB, mode="r", encoding='utf-8') as file2:
+                listA = file.readlines()[:-1]
+                listB = file2.readlines()[:-1]
+                for i in range(0, len(listA) - context_size - 1):
+                    model.zero_grad()
+                    x = listA[i: i + context_size]
+                    x = [line2word(y) for y in x]
+                    x_1 = [line2word(listA[i + context_size + 1])]
 
-            x = x + x_1
-            x = words_to_vecs(x)
+                    x = x + x_1
+                    x = words_to_vecs(x)
 
-            y = listB[i + context_size]
-            y = [line2word(x) for x in y]
-            mark_y = find_interpunction(y)
-            x = torch.tensor(x, dtype=torch.float)
+                    mark_y = find_interpunction(x[-1], classes)
+                    x = torch.tensor(x, dtype=torch.float).to(device)
 
-            output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
-            result_index = np.argmax(output)
-            if output[result_index] < threshold and len(mark_y) == 0:
-                correct += 1
-            if len(mark_y) > 0:
-                if classes[np.argmax(output)] == mark_y:
-                    correct += 1
-                else:
-                    incorrect += 1
-            else:
-                incorrect += 1
-
-    accuracy = correct / (correct + incorrect)
-    print(accuracy)
\ No newline at end of file
+                    output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
+                    output = output.cpu()
+                    output = output.detach().numpy()
+                    result_index = np.argmax(output)
+                    if output[result_index] < threshold and len(mark_y) == 0:
+                        correct += 1
+                    if len(mark_y) > 0:
+                        if classes[np.argmax(output)] == mark_y:
+                            correct += 1
+                        else:
+                            incorrect += 1
+                    else:
+                        incorrect += 1
+
+    accuracy = correct / (correct + incorrect)
+    print(accuracy)
\ No newline at end of file
diff --git a/util.py b/util.py
index f1610b7..34e797f 100644
--- a/util.py
+++ b/util.py
@@ -17,13 +17,14 @@ class Model(torch.nn.Module):
         2 num layers
         """
         self.lstm = torch.nn.LSTM(150, 300, 2)
-        self.dense2 = torch.nn.Linear(300, 7)
-        self.softmax = torch.nn.Softmax(dim=1)
+        self.dense2 = torch.nn.Linear(300, 8)
+        self.softmax = torch.nn.Softmax(dim=0)
 
     def forward(self, data, hidden_state, cell_state):
         data = self.dense1(data.T)
         data = self.tanh1(data)
         data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1), (hidden_state, cell_state))
+        # data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1))
         data = self.dense2(data)
         data = self.softmax(data)
         return data, (hidden_state, cell_state)
\ No newline at end of file