working eval
This commit is contained in:
parent
808efa0aad
commit
d013ea535d
13
main.py
13
main.py
@ -98,7 +98,7 @@ if mode == "train":
|
|||||||
for path in tqdm.tqdm(data_paths):
|
for path in tqdm.tqdm(data_paths):
|
||||||
with open(path, mode="r", encoding="utf-8") as file:
|
with open(path, mode="r", encoding="utf-8") as file:
|
||||||
list = file.readlines()
|
list = file.readlines()
|
||||||
for i in range(0, 10):
|
for i in range(0, len(list) - context_size - 1 - 1):
|
||||||
model.zero_grad()
|
model.zero_grad()
|
||||||
x = list[i: i + context_size]
|
x = list[i: i + context_size]
|
||||||
x = [line2word(y) for y in x]
|
x = [line2word(y) for y in x]
|
||||||
@ -135,6 +135,10 @@ elif mode == "evaluate":
|
|||||||
threshold = 0.3
|
threshold = 0.3
|
||||||
model.load_state_dict(torch.load("model-0.pt"))
|
model.load_state_dict(torch.load("model-0.pt"))
|
||||||
model.eval()
|
model.eval()
|
||||||
|
with open("hidden_state.pickle", "rb") as hs:
|
||||||
|
hidden_state = pickle.load(hs)
|
||||||
|
with open("cell_state.pickle", "rb") as cs:
|
||||||
|
cell_state = pickle.load(cs)
|
||||||
for pathA, pathB in zip(data_no_punc_paths, data_paths):
|
for pathA, pathB in zip(data_no_punc_paths, data_paths):
|
||||||
with open(pathA, mode="r", encoding='utf-8') as file:
|
with open(pathA, mode="r", encoding='utf-8') as file:
|
||||||
with open(pathB, mode="r", encoding='utf-8') as file2:
|
with open(pathB, mode="r", encoding='utf-8') as file2:
|
||||||
@ -155,9 +159,12 @@ elif mode == "evaluate":
|
|||||||
output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
|
output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
|
||||||
output = output.cpu()
|
output = output.cpu()
|
||||||
output = output.detach().numpy()
|
output = output.detach().numpy()
|
||||||
|
output = np.mean(output, axis=0)
|
||||||
|
output = np.squeeze(output)
|
||||||
result_index = np.argmax(output)
|
result_index = np.argmax(output)
|
||||||
if output[result_index] < threshold and len(mark_y) == 0:
|
# if output[result_index] < threshold:
|
||||||
correct += 1
|
# incorrect += 1
|
||||||
|
|
||||||
if len(mark_y) > 0:
|
if len(mark_y) > 0:
|
||||||
if classes[np.argmax(output)] == mark_y:
|
if classes[np.argmax(output)] == mark_y:
|
||||||
correct += 1
|
correct += 1
|
||||||
|
Loading…
Reference in New Issue
Block a user