working eval

This commit is contained in:
wangobango 2021-06-16 23:05:03 +02:00
parent 808efa0aad
commit d013ea535d

13
main.py
View File

@ -98,7 +98,7 @@ if mode == "train":
for path in tqdm.tqdm(data_paths): for path in tqdm.tqdm(data_paths):
with open(path, mode="r", encoding="utf-8") as file: with open(path, mode="r", encoding="utf-8") as file:
list = file.readlines() list = file.readlines()
for i in range(0, 10): for i in range(0, len(list) - context_size - 1 - 1):
model.zero_grad() model.zero_grad()
x = list[i: i + context_size] x = list[i: i + context_size]
x = [line2word(y) for y in x] x = [line2word(y) for y in x]
@ -135,6 +135,10 @@ elif mode == "evaluate":
threshold = 0.3 threshold = 0.3
model.load_state_dict(torch.load("model-0.pt")) model.load_state_dict(torch.load("model-0.pt"))
model.eval() model.eval()
with open("hidden_state.pickle", "rb") as hs:
hidden_state = pickle.load(hs)
with open("cell_state.pickle", "rb") as cs:
cell_state = pickle.load(cs)
for pathA, pathB in zip(data_no_punc_paths, data_paths): for pathA, pathB in zip(data_no_punc_paths, data_paths):
with open(pathA, mode="r", encoding='utf-8') as file: with open(pathA, mode="r", encoding='utf-8') as file:
with open(pathB, mode="r", encoding='utf-8') as file2: with open(pathB, mode="r", encoding='utf-8') as file2:
@ -155,9 +159,12 @@ elif mode == "evaluate":
output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state) output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
output = output.cpu() output = output.cpu()
output = output.detach().numpy() output = output.detach().numpy()
output = np.mean(output, axis=0)
output = np.squeeze(output)
result_index = np.argmax(output) result_index = np.argmax(output)
if output[result_index] < threshold and len(mark_y) == 0: # if output[result_index] < threshold:
correct += 1 # incorrect += 1
if len(mark_y) > 0: if len(mark_y) > 0:
if classes[np.argmax(output)] == mark_y: if classes[np.argmax(output)] == mark_y:
correct += 1 correct += 1