From 6fde30cb6e612a4f9d8e505e44c5160dac3ce32a Mon Sep 17 00:00:00 2001
From: wangobango
Date: Wed, 16 Jun 2021 13:03:28 +0200
Subject: [PATCH] fix

---
 main.py | 95 +++++++++++++++++++++++++++++++++++++--------------------
 1 file changed, 62 insertions(+), 33 deletions(-)

diff --git a/main.py b/main.py
index 80086d4..0bddde9 100644
--- a/main.py
+++ b/main.py
@@ -28,8 +28,8 @@ def words_to_vecs(list_of_words):
     return [nlp(x).vector for x in list_of_words]
 
 def softXEnt(input, target):
-    m = torch.nn.LogSoftmax(dim = 1)
-    logprobs = m(input)
+    # nn.LogSoftmax takes dim in the constructor, not in forward()
+    logprobs = torch.nn.functional.log_softmax(input, dim=1)
     return -(target * logprobs).sum() / input.shape[0]
 
 def compute_class_vector(mark, classes):
@@ -43,6 +43,9 @@ def compute_class_vector(mark, classes):
 
 data_dir = "./fa/poleval_final_dataset/train"
 data_nopunc_dir = "./fa/poleval_final_dataset1/train"
+mode = "train"
+# mode = "evaluate"
+# mode = "generate"
 
 data_paths = os.listdir(data_dir)
 data_paths = [data_dir + "/" + x for x in data_paths]
@@ -61,36 +64,62 @@ model.train()
 optimizer = torch.optim.AdamW(model.parameters(), lr=0.02)
 loss_function = softXEnt
 
-for epoch in range(epochs):
-    for path in tqdm.tqdm(data_paths):
-        with open(path, "r") as file:
-            list = file.readlines()[:-1]
-        for i in range(0, len(list) - context_size - 1):
-            model.zero_grad()
-            x = list[i: i + context_size]
-            x = [line2word(y) for y in x]
-            x_1 = [line2word(list[i + context_size + 1])]
-            x = x + x_1
-            x = words_to_vecs(x)
-            mark = find_interpunction(x, classes)
-            mark = words_to_vecs(mark)
-
-            x = torch.tensor(x, dtype=torch.float)
-            mark = torch.tensor(mark, dtype=torch.float)
-            output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
-            output = output.squeeze(1)
-            loss = loss_function(output, compute_class_vector(mark, classes))
-            loss.backward()
-            optimizer.step()
-            hidden_state = hidden_state.detach()
-            cell_state = cell_state.detach()
+if mode == "train":
+    for epoch in range(epochs):
+        for path in tqdm.tqdm(data_paths):
+            with open(path, "r") as file:
+                lines = file.readlines()[:-1]  # renamed from `list` to avoid shadowing the built-in
+            for i in range(0, len(lines) - context_size - 1):
+                model.zero_grad()
+                x = lines[i: i + context_size]
+                x = [line2word(y) for y in x]
+                x_1 = [line2word(lines[i + context_size + 1])]
+                # the target mark is read from the last word of the context window
+                mark = find_interpunction(x[-1], classes)
 
-            """
-            vector -> (96,), np nadarray
-            """
-    print("Epoch: {}".format(epoch))
-    torch.save(
-        model.state_dict(),
-        os.path.join("./", f"{output_prefix}-{epoch}.pt"),
-    )
\ No newline at end of file
+                x = x + x_1
+                x = words_to_vecs(x)
+                x = torch.tensor(x, dtype=torch.float)
+
+                output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
+                output = output.squeeze(1)
+                loss = loss_function(output, compute_class_vector(mark, classes))
+                loss.backward()
+                optimizer.step()
+                # detach to truncate backpropagation through time at each step
+                hidden_state = hidden_state.detach()
+                cell_state = cell_state.detach()
+
+        print("Epoch: {}".format(epoch))
+        torch.save(
+            model.state_dict(),
+            os.path.join("./", f"{output_prefix}-{epoch}.pt"),
+        )
+
+elif mode == "evaluate":
+    # pair each punctuation-stripped file with its punctuated counterpart
+    nopunc_paths = [data_nopunc_dir + "/" + x for x in os.listdir(data_nopunc_dir)]
+    correct, total = 0, 0
+    for pathA, pathB in zip(nopunc_paths, data_paths):
+        with open(pathA, "r") as file:
+            listA = file.readlines()[:-1]
+        with open(pathB, "r") as file:
+            listB = file.readlines()[:-1]
+        for i in range(0, len(listA) - context_size - 1):
+            x = listA[i: i + context_size]
+            x = [line2word(y) for y in x]
+            x_1 = [line2word(listA[i + context_size + 1])]
+
+            x = x + x_1
+            x = words_to_vecs(x)
+            x = torch.tensor(x, dtype=torch.float)
+
+            # the reference mark comes from the punctuated copy of the file
+            y = line2word(listB[i + context_size])
+            mark_y = find_interpunction(y, classes)
+
+            with torch.no_grad():
+                output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
+            total += 1
+            if classes[torch.argmax(output).item()] == mark_y:
+                correct += 1
+    print("Accuracy: {:.2%}".format(correct / max(total, 1)))
\ No newline at end of file
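
Note on the softXEnt hunk: `torch.nn.LogSoftmax` only accepts `dim` in its
constructor; passing `dim` to the module call raises a TypeError, which is why
the patch uses the functional form. A minimal sketch of the soft-label
cross-entropy this file computes, assuming a (batch, classes) logits layout
(the shapes here are illustrative, not taken from main.py):

    import torch
    import torch.nn.functional as F

    logits = torch.randn(4, 3)                    # assumed (batch, classes)
    target = torch.tensor([[1.0, 0.0, 0.0]] * 4)  # soft labels, rows sum to 1

    # m = torch.nn.LogSoftmax(); m(logits, dim=1)  -> TypeError: dim belongs
    # in the constructor, not the call
    logprobs = F.log_softmax(logits, dim=1)       # functional form takes dim per call
    loss = -(target * logprobs).sum() / logits.shape[0]
    print(loss.item())

With one-hot targets this reduces to the ordinary cross-entropy mean, which is
a quick way to sanity-check the formula against torch.nn.CrossEntropyLoss.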