parent 40aa8aa379
commit 808efa0aad

.gitignore (vendored): 4 changes
@@ -7,4 +7,6 @@
 .DS_Store
 .token
 .vscode
 fa/*
+*.pt
+*.pickle
main.py: 80 changes
@@ -4,6 +4,7 @@ import spacy
 import torch
 import numpy as np
 import tqdm
+import pickle

 def clean_string(str):
     str = str.replace('\n', '')
@@ -36,7 +37,7 @@ def compute_class_vector(mark, classes):
     result = np.zeros(len(classes))
     for x in range(len(classes)):
         if classes[x] == mark[0]:
-            result[x] == 1
+            result[x] = 1.0
     return torch.tensor(result, dtype=torch.long)

 def prepare_input(index, data, context_size):
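Note on this hunk: the old line `result[x] == 1` is a comparison whose result is discarded, so the target vector always stayed all zeros; the fix makes it an assignment. A minimal self-contained sketch of the corrected helper, with the class list taken from this commit:

    import numpy as np
    import torch

    def compute_class_vector(mark, classes):
        # One-hot encode mark[0] over the punctuation classes.
        result = np.zeros(len(classes))
        for x in range(len(classes)):
            if classes[x] == mark[0]:
                result[x] = 1.0  # was `result[x] == 1`: a no-op comparison
        return torch.tensor(result, dtype=torch.long)

    # Example: a comma against the class list defined in main.py
    print(compute_class_vector([','], [',', '.', '?', '!', '-', ':', '...', '']))
    # tensor([1, 0, 0, 0, 0, 0, 0, 0])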
@@ -65,7 +66,10 @@ mode = "train"
 data_paths = os.listdir(data_dir)
 data_paths = [data_dir + "/" + x for x in data_paths]

-classes = [',', '.', '?', '!', '-', ':', '...']
+data_no_punc_paths = os.listdir(data_nopunc_dir)
+data_no_punc_paths = [data_nopunc_dir + "/" + x for x in data_no_punc_paths]
+
+classes = [',', '.', '?', '!', '-', ':', '...', '']
 nlp = spacy.load("pl_core_news_sm")
 context_size = 5

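The class list gains an empty-string entry, presumably the "no punctuation" case; this lines up with the output layer widening from 7 to 8 units in util.py below. A quick consistency check one could keep next to the model definition (a sketch, not part of the commit):

    classes = [',', '.', '?', '!', '-', ':', '...', '']
    assert len(classes) == 8  # must match torch.nn.Linear(300, 8) in util.py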
@@ -93,8 +97,8 @@ if mode == "train":
     for epoch in range(epochs):
         for path in tqdm.tqdm(data_paths):
             with open(path, mode="r", encoding="utf-8") as file:
-                list = file.readlines()[:-1]
-                for i in range(0, len(list) - context_size - 1):
+                list = file.readlines()
+                for i in range(0, 10):
                     model.zero_grad()
                     x = list[i: i + context_size]
                     x = [line2word(y) for y in x]
@@ -109,7 +113,7 @@ if mode == "train":
                     x = x.to(device)
                     output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
                     output = output.squeeze(1)
-                    loss = loss_function(output, compute_class_vector(mark, classes).to(device))
+                    loss = loss_function(output, compute_class_vector([mark], classes).to(device))
                     loss.backward()
                     optimizer.step()
                     hidden_state = hidden_state.detach()
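Wrapping `mark` in a list here matters because `compute_class_vector` reads `mark[0]`: on a bare string that yields the first character, so the three-character class `'...'` would have been encoded as `'.'`. A two-line illustration:

    mark = '...'
    print(mark[0])    # '.'  (first character only)
    print([mark][0])  # '...' (the whole mark, as intended)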
@@ -120,43 +124,47 @@ if mode == "train":
             model.state_dict(),
             os.path.join("./", f"{output_prefix}-{epoch}.pt"),
         )
+        with open("hidden_state.pickle", "wb") as hs:
+            pickle.dump(hidden_state, hs)
+        with open("cell_state.pickle", "wb") as cs:
+            pickle.dump(cell_state, cs)

 elif mode == "evaluate":
     correct = 0
     incorrect = 0
     threshold = 0.3
-    for pathA, pathB in zip(data_nopunc_dir, data_dir):
-        listA = []
-        listB = []
+    model.load_state_dict(torch.load("model-0.pt"))
+    model.eval()
+    for pathA, pathB in zip(data_no_punc_paths, data_paths):
         with open(pathA, mode="r", encoding='utf-8') as file:
-            listA = file.readlines()[:-1]
-        with open(pathB, mode="r", encoding='utf-8') as file:
-            listb = file.readlines()[:-1]
+            with open(pathB, mode="r", encoding='utf-8') as file2:
+                listA = file.readlines()[:-1]
+                listB = file2.readlines()[:-1]
         for i in range(0, len(listA) - context_size - 1):
             model.zero_grad()
             x = listA[i: i + context_size]
             x = [line2word(y) for y in x]
             x_1 = [line2word(listA[i + context_size + 1])]

             x = x + x_1
             x = words_to_vecs(x)

-            y = listB[i + context_size]
-            y = [line2word(x) for x in y]
-            mark_y = find_interpunction(y)
-            x = torch.tensor(x, dtype=torch.float)
+            mark_y = find_interpunction(x[-1], classes)
+            x = torch.tensor(x, dtype=torch.float).to(device)

             output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
+            output = output.cpu()
+            output = output.detach().numpy()
             result_index = np.argmax(output)
             if output[result_index] < threshold and len(mark_y) == 0:
                 correct += 1
             if len(mark_y) > 0:
                 if classes[np.argmax(output)] == mark_y:
                     correct += 1
                 else:
                     incorrect += 1
             else:
                 incorrect += 1

     accuracy = correct / (correct + incorrect)
     print(accuracy)
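This hunk persists the final hidden and cell states with pickle after saving the checkpoint, and the evaluate branch now restores the trained weights and calls model.eval() before scoring. A minimal sketch of loading everything back, assuming `model` is constructed as in util.py and the file names used in this commit:

    import pickle
    import torch

    model.load_state_dict(torch.load("model-0.pt"))
    model.eval()  # switch off training-only behavior before evaluation

    # Restore the recurrent state saved at the end of training
    with open("hidden_state.pickle", "rb") as hs:
        hidden_state = pickle.load(hs)
    with open("cell_state.pickle", "rb") as cs:
        cell_state = pickle.load(cs)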
util.py: 5 changes
@@ -17,13 +17,14 @@ class Model(torch.nn.Module):
         2 num layers
         """
         self.lstm = torch.nn.LSTM(150, 300, 2)
-        self.dense2 = torch.nn.Linear(300, 7)
-        self.softmax = torch.nn.Softmax(dim=1)
+        self.dense2 = torch.nn.Linear(300, 8)
+        self.softmax = torch.nn.Softmax(dim=0)

     def forward(self, data, hidden_state, cell_state):
         data = self.dense1(data.T)
         data = self.tanh1(data)
         data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1), (hidden_state, cell_state))
+        # data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1))
         data = self.dense2(data)
         data = self.softmax(data)
         return data, (hidden_state, cell_state)
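On the Softmax change: `torch.nn.Softmax(dim=...)` normalizes slices along the chosen dimension to sum to 1. The forward pass produces a `(seq_len, 1, 8)` tensor (`unsqueeze(1)` adds the batch axis), so `dim=0` normalizes across sequence positions, while `dim=-1` would normalize across the 8 punctuation classes; which is wanted depends on how the output is consumed in main.py. A small check (a sketch, not part of the commit):

    import torch

    t = torch.randn(5, 1, 8)  # (seq_len, batch=1, classes), as in Model.forward
    print(torch.nn.Softmax(dim=0)(t).sum(dim=0))    # all ones: normalized over seq_len
    print(torch.nn.Softmax(dim=-1)(t).sum(dim=-1))  # all ones: normalized over classes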