From a894711f00aa0c0d9bc3f6afea8886cedb550b79 Mon Sep 17 00:00:00 2001
From: wangobango
Date: Thu, 17 Jun 2021 17:54:10 +0200
Subject: [PATCH] dupa

---
 main.py | 67 ++++++++++++++++++++++++++++++++++++---------------------
 util.py |  6 +++---
 2 files changed, 46 insertions(+), 27 deletions(-)

diff --git a/main.py b/main.py
index d742fe1..de67c26 100644
--- a/main.py
+++ b/main.py
@@ -1,4 +1,5 @@
 import os
+from typing import Counter
 from util import Model
 import spacy
 import torch
@@ -38,7 +39,7 @@ def compute_class_vector(mark, classes):
     for x in range(len(classes)):
         if classes[x] == mark[0]:
             result[x] = 1.0
-    return torch.tensor(result, dtype=torch.long)
+    return torch.tensor(result, dtype=torch.float)
 
 def prepare_input(index, data, context_size):
     x = data[index: index + context_size]
@@ -76,28 +77,32 @@ context_size = 5
 model = Model()
 epochs = 5
 output_prefix = "model"
+train_loss_acc = 30
 device=torch.device("cuda")
 model = model.cuda()
+optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
+scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 15, 0.0001)
+loss_function = torch.nn.MSELoss()
 model.train()
-optimizer = torch.optim.AdamW(model.parameters(), lr=0.02)
-loss_function = softXEnt
 
 """
 TODO
 1) add batch processing
-2) add a separate fully connected network that uses the audio to detect whether to apply punctuation or not
+2) change the loss function
 """
 
-hidden_state = torch.randn((2, 1, 300), requires_grad=True).to(device)
-cell_state = torch.randn((2, 1, 300), requires_grad=True).to(device)
+# hidden_state = torch.randn((2, 1, 300), requires_grad=True).to(device)
+# cell_state = torch.randn((2, 1, 300), requires_grad=True).to(device)
+counter = 0
+# model.load_state_dict(torch.load("model-4.pt"))
 
 if mode == "train":
     for epoch in range(epochs):
-        for path in tqdm.tqdm(data_paths):
+        for path in tqdm.tqdm(data_paths[:20]):
             with open(path, mode="r", encoding="utf-8") as file:
-                list = file.readlines()
+                list = file.readlines()[:100]
                 for i in range(0, len(list) - context_size - 1 - 1):
                     model.zero_grad()
                     x = list[i: i + context_size]
@@ -105,40 +110,54 @@ if mode == "train":
                     x_1 = [line2word(list[i + context_size + 1])]
                     mark = find_interpunction(x[-1], classes)
 
+                    if mark == '':
+                        continue
+
                     x = x + x_1
                     x = words_to_vecs(x)
                     x = torch.tensor(x, dtype=torch.float)
                     x = x.to(device)
 
-                    output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
+                    output, (_,_) = model.forward(x)
                     output = output.squeeze(1)
 
-                    loss = loss_function(output, compute_class_vector([mark], classes).to(device))
+                    class_vector = compute_class_vector([mark], classes).to(device)
+                    loss = loss_function(torch.mean(output, 0), class_vector)
+
+                    if counter % 10 == 0:
+                        print(torch.mean(output, 0))
+                        print(loss)
+                        print(class_vector)
+
                     loss.backward()
-                    optimizer.step()
 
-                    hidden_state = hidden_state.detach()
-                    cell_state = cell_state.detach()
+                    if counter % train_loss_acc == 0:
+                        scheduler.step()
+                        optimizer.step()
+                        optimizer.zero_grad()
+                        model.zero_grad()
+
+                    counter += 1
 
         print("Epoch: {}".format(epoch))
         torch.save(
             model.state_dict(),
             os.path.join("./", f"{output_prefix}-{epoch}.pt"),
         )
-    with open("hidden_state.pickle", "wb") as hs:
-        pickle.dump(hidden_state, hs)
-    with open("cell_state.pickle", "wb") as cs:
-        pickle.dump(cell_state, cs)
+    # with open("hidden_state.pickle", "wb") as hs:
+    #     pickle.dump(hidden_state, hs)
+    # with open("cell_state.pickle", "wb") as cs:
+    #     pickle.dump(cell_state, cs)
 
 elif mode == "evaluate":
     correct = 0
     incorrect = 0
     threshold = 0.3
-    model.load_state_dict(torch.load("model-0.pt"))
+    model.load_state_dict(torch.load("model-4.pt"))
     model.eval()
 
-    with open("hidden_state.pickle", "rb") as hs:
-        hidden_state = pickle.load(hs)
-    with open("cell_state.pickle", "rb") as cs:
-        cell_state = pickle.load(cs)
+    # with open("hidden_state.pickle", "rb") as hs:
+    #     hidden_state = pickle.load(hs)
+    # with open("cell_state.pickle", "rb") as cs:
+    #     cell_state = pickle.load(cs)
 
     for pathA, pathB in zip(data_no_punc_paths, data_paths):
         with open(pathA, mode="r", encoding='utf-8') as file:
             with open(pathB, mode="r", encoding='utf-8') as file2:
@@ -156,7 +175,7 @@ elif mode == "evaluate":
                     mark_y = find_interpunction(x[-1], classes)
 
                     x = torch.tensor(x, dtype=torch.float).to(device)
-                    output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
+                    output, (hidden_state, cell_state) = model.forward(x)
                     output = output.cpu()
                     output = output.detach().numpy()
                     output = np.mean(output, axis=0)
@@ -164,7 +183,7 @@ elif mode == "evaluate":
                     result_index = np.argmax(output)
                     # if output[result_index] < threshold:
                     #     incorrect += 1
-
+                    print(output)
                     if len(mark_y) > 0:
                         if classes[np.argmax(output)] == mark_y:
                             correct += 1
diff --git a/util.py b/util.py
index 34e797f..0cbcfbd 100644
--- a/util.py
+++ b/util.py
@@ -20,11 +20,11 @@ class Model(torch.nn.Module):
         self.dense2 = torch.nn.Linear(300, 8)
         self.softmax = torch.nn.Softmax(dim=0)
 
-    def forward(self, data, hidden_state, cell_state):
+    def forward(self, data):
         data = self.dense1(data.T)
         data = self.tanh1(data)
-        data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1), (hidden_state, cell_state))
-        # data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1))
+        # data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1), (hidden_state, cell_state))
+        data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1))
         data = self.dense2(data)
         data = self.softmax(data)
         return data, (hidden_state, cell_state)
\ No newline at end of file
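For reference, below is a minimal self-contained sketch of the gradient-accumulation pattern the main.py hunk moves to: summing loss.backward() over train_loss_acc samples and applying a single optimizer step. The model, the random tensors, and the name accum_steps are illustrative stand-ins, not the repository's Model or corpus. Two widely documented PyTorch details are worth noting against the patch: since PyTorch 1.1 the expected ordering is optimizer.step() before scheduler.step() (the patch calls them in reverse, which PyTorch flags with a UserWarning), and each per-sample loss is usually divided by the accumulation count so the summed gradient matches one large batch.

    import torch

    model = torch.nn.Linear(300, 8)          # stand-in for util.Model
    optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=15, eta_min=0.0001)
    loss_function = torch.nn.MSELoss()
    accum_steps = 30                         # counterpart of train_loss_acc

    model.train()
    optimizer.zero_grad()
    for step in range(300):
        x = torch.randn(1, 300)              # stand-in for a word-vector window
        target = torch.zeros(1, 8)           # stand-in for compute_class_vector(...)
        target[0, step % 8] = 1.0

        # Scale each loss so the accumulated gradient matches one
        # batch of accum_steps samples.
        loss = loss_function(model(x), target) / accum_steps
        loss.backward()

        if (step + 1) % accum_steps == 0:
            optimizer.step()                 # step the optimizer first...
            scheduler.step()                 # ...then the LR scheduler
            optimizer.zero_grad()

One more detail of the patch itself: the counter % train_loss_acc == 0 test is true when counter == 0, so the first optimizer step fires after a single sample; testing (counter + 1) % train_loss_acc == 0, as above, steps only after a full accumulation window.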