diff --git a/main.py b/main.py index 0949e3e..d742fe1 100644 --- a/main.py +++ b/main.py @@ -98,7 +98,7 @@ if mode == "train": for path in tqdm.tqdm(data_paths): with open(path, mode="r", encoding="utf-8") as file: list = file.readlines() - for i in range(0, 10): + for i in range(0, len(list) - context_size - 1 - 1): model.zero_grad() x = list[i: i + context_size] x = [line2word(y) for y in x] @@ -135,6 +135,10 @@ elif mode == "evaluate": threshold = 0.3 model.load_state_dict(torch.load("model-0.pt")) model.eval() + with open("hidden_state.pickle", "rb") as hs: + hidden_state = pickle.load(hs) + with open("cell_state.pickle", "rb") as cs: + cell_state = pickle.load(cs) for pathA, pathB in zip(data_no_punc_paths, data_paths): with open(pathA, mode="r", encoding='utf-8') as file: with open(pathB, mode="r", encoding='utf-8') as file2: @@ -155,9 +159,12 @@ elif mode == "evaluate": output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state) output = output.cpu() output = output.detach().numpy() + output = np.mean(output, axis=0) + output = np.squeeze(output) result_index = np.argmax(output) - if output[result_index] < threshold and len(mark_y) == 0: - correct += 1 + # if output[result_index] < threshold: + # incorrect += 1 + if len(mark_y) > 0: if classes[np.argmax(output)] == mark_y: correct += 1