This commit is contained in:
wangobango 2021-06-16 22:23:55 +02:00
parent 40aa8aa379
commit 808efa0aad
3 changed files with 50 additions and 39 deletions

4
.gitignore vendored
View File

@@ -7,4 +7,6 @@
.DS_Store .DS_Store
.token .token
.vscode .vscode
fa/* fa/*
*.pt
*.pickle

80
main.py
View File

@@ -4,6 +4,7 @@ import spacy
import torch import torch
import numpy as np import numpy as np
import tqdm import tqdm
import pickle
def clean_string(str): def clean_string(str):
str = str.replace('\n', '') str = str.replace('\n', '')
@@ -36,7 +37,7 @@ def compute_class_vector(mark, classes):
result = np.zeros(len(classes)) result = np.zeros(len(classes))
for x in range(len(classes)): for x in range(len(classes)):
if classes[x] == mark[0]: if classes[x] == mark[0]:
result[x] == 1 result[x] = 1.0
return torch.tensor(result, dtype=torch.long) return torch.tensor(result, dtype=torch.long)
def prepare_input(index, data, context_size): def prepare_input(index, data, context_size):
@@ -65,7 +66,10 @@ mode = "train"
data_paths = os.listdir(data_dir) data_paths = os.listdir(data_dir)
data_paths = [data_dir + "/" + x for x in data_paths] data_paths = [data_dir + "/" + x for x in data_paths]
classes = [',', '.', '?', '!', '-', ':', '...'] data_no_punc_paths = os.listdir(data_nopunc_dir)
data_no_punc_paths = [data_nopunc_dir + "/" + x for x in data_no_punc_paths]
classes = [',', '.', '?', '!', '-', ':', '...', '']
nlp = spacy.load("pl_core_news_sm") nlp = spacy.load("pl_core_news_sm")
context_size = 5 context_size = 5
@@ -93,8 +97,8 @@ if mode == "train":
for epoch in range(epochs): for epoch in range(epochs):
for path in tqdm.tqdm(data_paths): for path in tqdm.tqdm(data_paths):
with open(path, mode="r", encoding="utf-8") as file: with open(path, mode="r", encoding="utf-8") as file:
list = file.readlines()[:-1] list = file.readlines()
for i in range(0, len(list) - context_size - 1): for i in range(0, 10):
model.zero_grad() model.zero_grad()
x = list[i: i + context_size] x = list[i: i + context_size]
x = [line2word(y) for y in x] x = [line2word(y) for y in x]
@@ -109,7 +113,7 @@ if mode == "train":
x = x.to(device) x = x.to(device)
output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state) output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
output = output.squeeze(1) output = output.squeeze(1)
loss = loss_function(output, compute_class_vector(mark, classes).to(device)) loss = loss_function(output, compute_class_vector([mark], classes).to(device))
loss.backward() loss.backward()
optimizer.step() optimizer.step()
hidden_state = hidden_state.detach() hidden_state = hidden_state.detach()
@@ -120,43 +124,47 @@ if mode == "train":
model.state_dict(), model.state_dict(),
os.path.join("./", f"{output_prefix}-{epoch}.pt"), os.path.join("./", f"{output_prefix}-{epoch}.pt"),
) )
with open("hidden_state.pickle", "wb") as hs:
pickle.dump(hidden_state, hs)
with open("cell_state.pickle", "wb") as cs:
pickle.dump(cell_state, cs)
elif mode == "evaluate": elif mode == "evaluate":
correct = 0 correct = 0
incorrect = 0 incorrect = 0
threshold = 0.3 threshold = 0.3
for pathA, pathB in zip(data_nopunc_dir, data_dir): model.load_state_dict(torch.load("model-0.pt"))
listA = [] model.eval()
listB = [] for pathA, pathB in zip(data_no_punc_paths, data_paths):
with open(pathA, mode="r", encoding='utf-8') as file: with open(pathA, mode="r", encoding='utf-8') as file:
listA = file.readlines()[:-1] with open(pathB, mode="r", encoding='utf-8') as file2:
with open(pathB, mode="r", encoding='utf-8') as file: listA = file.readlines()[:-1]
listb = file.readlines()[:-1] listB = file2.readlines()[:-1]
for i in range(0, len(listA) - context_size - 1): for i in range(0, len(listA) - context_size - 1):
model.zero_grad() model.zero_grad()
x = listA[i: i + context_size] x = listA[i: i + context_size]
x = [line2word(y) for y in x] x = [line2word(y) for y in x]
x_1 = [line2word(listA[i + context_size + 1])] x_1 = [line2word(listA[i + context_size + 1])]
x = x + x_1 x = x + x_1
x = words_to_vecs(x) x = words_to_vecs(x)
y = listB[i + context_size] mark_y = find_interpunction(x[-1], classes)
y = [line2word(x) for x in y] x = torch.tensor(x, dtype=torch.float).to(device)
mark_y = find_interpunction(y)
x = torch.tensor(x, dtype=torch.float)
output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state) output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
result_index = np.argmax(output) output = output.cpu()
if output[result_index] < threshold and len(mark_y) == 0: output = output.detach().numpy()
correct += 1 result_index = np.argmax(output)
if len(mark_y) > 0: if output[result_index] < threshold and len(mark_y) == 0:
if classes[np.argmax(output)] == mark_y: correct += 1
correct += 1 if len(mark_y) > 0:
else: if classes[np.argmax(output)] == mark_y:
incorrect += 1 correct += 1
else: else:
incorrect += 1 incorrect += 1
else:
accuracy = correct / (correct + incorrect) incorrect += 1
print(accuracy)
accuracy = correct / (correct + incorrect)
print(accuracy)

View File

@@ -17,13 +17,14 @@ class Model(torch.nn.Module):
2 num layers 2 num layers
""" """
self.lstm = torch.nn.LSTM(150, 300, 2) self.lstm = torch.nn.LSTM(150, 300, 2)
self.dense2 = torch.nn.Linear(300, 7) self.dense2 = torch.nn.Linear(300, 8)
self.softmax = torch.nn.Softmax(dim=1) self.softmax = torch.nn.Softmax(dim=0)
def forward(self, data, hidden_state, cell_state): def forward(self, data, hidden_state, cell_state):
data = self.dense1(data.T) data = self.dense1(data.T)
data = self.tanh1(data) data = self.tanh1(data)
data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1), (hidden_state, cell_state)) data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1), (hidden_state, cell_state))
# data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1))
data = self.dense2(data) data = self.dense2(data)
data = self.softmax(data) data = self.softmax(data)
return data, (hidden_state, cell_state) return data, (hidden_state, cell_state)