punctuation_restoration/main.py

196 lines
6.3 KiB
Python
Raw Permalink Normal View History

2021-06-16 12:51:01 +02:00
import os
2021-06-17 17:54:10 +02:00
from typing import Counter
2021-06-16 12:51:01 +02:00
from util import Model
import spacy
import torch
import numpy as np
import tqdm
2021-06-16 22:23:55 +02:00
import pickle
2021-06-16 12:51:01 +02:00
def clean_string(str):
    """Return ``str`` with every newline character (``\\n``) removed."""
    return str.replace('\n', '')
def extract_word(line):
    """Return the second field of ``line`` when split on single spaces.

    Raises:
        IndexError: if ``line`` contains no space character.
    """
    fields = line.split(" ")
    return fields[1]
def line2word(line):
    """Return the word stored in a raw data line.

    The word is the second space-separated field of ``line`` (see
    ``extract_word``) with newline characters stripped (see ``clean_string``).
    """
    return clean_string(extract_word(line))
def find_interpunction(line, classes):
    """Return the punctuation class that occurs in ``line``.

    Every entry of ``classes`` is tested as a substring of ``line``. When
    several entries match, the longest one wins, so a multi-character mark
    such as ``'...'`` is not masked by a shorter mark it contains (``'.'``).
    Ties between equally long matches are resolved in ``classes`` order.

    Args:
        line: Text to inspect (typically a single word).
        classes: Sequence of punctuation strings. An empty-string entry
            matches every line and therefore acts as the "no punctuation"
            class.

    Returns:
        The matching class as a ``str``, or ``''`` when nothing matches.
    """
    matches = [mark for mark in classes if mark in line]
    if not matches:
        # Always return a string: callers compare the result against '' and
        # against entries of ``classes``.
        return ''
    # max() keeps the first of several equally long candidates, which
    # preserves the priority given by the order of ``classes``.
    return max(matches, key=len)
def words_to_vecs(list_of_words):
    """Return the ``.vector`` of every word, in input order.

    Each word is passed individually through the module-level ``nlp``
    pipeline; the result is a plain list with one vector per word.
    """
    vectors = []
    for word in list_of_words:
        vectors.append(nlp(word).vector)
    return vectors
def softXEnt(input, target):
    """Cross entropy between soft targets and the log-softmax of ``input``.

    The log-softmax is taken along dimension 1, multiplied element-wise by
    ``target``, summed over all elements, negated, and divided by the size
    of the first dimension of ``input``.
    """
    log_probs = torch.log_softmax(input, dim=1)
    batch_size = input.shape[0]
    weighted = target * log_probs
    return -weighted.sum() / batch_size
def compute_class_vector(mark, classes):
    """Build the target vector for a punctuation label.

    Args:
        mark: Indexable whose first element (``mark[0]``) is the label.
        classes: Sequence of all class labels.

    Returns:
        A float32 tensor of length ``len(classes)`` holding 1.0 at every
        position whose class equals ``mark[0]`` and 0.0 elsewhere.
    """
    one_hot = np.zeros(len(classes))
    for position, candidate in enumerate(classes):
        if candidate == mark[0]:
            one_hot[position] = 1.0
    return torch.tensor(one_hot, dtype=torch.float)
2021-06-16 12:51:01 +02:00
2021-06-16 19:04:43 +02:00
def prepare_input(index, data, context_size):
    """Build one (features, target) pair from a window of ``data``.

    The window consists of ``context_size`` consecutive lines starting at
    ``index`` plus one look-ahead line at ``index + context_size + 1``. The
    target label is the punctuation found in the last context word.

    Args:
        index: Position of the first context line in ``data``.
        data: Sequence of raw data lines (as accepted by ``line2word``).
        context_size: Number of context lines in the window.

    Returns:
        A tuple ``(x, target)``: ``x`` is a float tensor with one embedding
        row per word (context words followed by the look-ahead word) and
        ``target`` is the class vector from ``compute_class_vector``.

    Raises:
        IndexError: if the window or the look-ahead line lies outside
            ``data``.
    """
    context_words = [line2word(line) for line in data[index: index + context_size]]
    # Read the look-ahead line from ``data``; the previous code indexed the
    # name ``list`` here, which is not the sequence passed in.
    # NOTE(review): the "+ 1" skips the line directly after the context -
    # this mirrors the training loop; confirm it is intended.
    lookahead_word = line2word(data[index + context_size + 1])

    mark = find_interpunction(context_words[-1], classes)
    # compute_class_vector reads ``mark[0]``, so a string label has to be
    # wrapped in a list; otherwise '' raises IndexError and '...' is
    # truncated to '.'.
    if isinstance(mark, str):
        mark = [mark]

    vectors = words_to_vecs(context_words + [lookahead_word])
    # Stack into a single ndarray first: building a tensor from a Python
    # list of arrays is very slow.
    x = torch.tensor(np.array(vectors), dtype=torch.float)
    return x, compute_class_vector(mark, classes)
def prepare_batch(index, data, context_size, batch_size):
    """Return ``batch_size`` consecutive samples starting at ``index``.

    Each element is the result of ``prepare_input`` for one start position
    in ``range(index, index + batch_size)``.
    """
    return [
        prepare_input(start, data, context_size)
        for start in range(index, index + batch_size)
    ]
2021-06-16 12:51:01 +02:00
# --- Dataset locations -------------------------------------------------------
data_dir = "./fa/poleval_final_dataset/train"
data_nopunc_dir = "./fa/poleval_final_dataset1/train"

# Selects which branch of the script runs.
mode = "train"
# mode = "evaluate"
# mode = "generate"

# os.listdir() returns entries in arbitrary order. Both listings are sorted so
# that runs are reproducible and so that pairing the two lists positionally
# matches files by name.
# NOTE(review): assumes both directories contain identically named files.
data_paths = [data_dir + "/" + name for name in sorted(os.listdir(data_dir))]
data_no_punc_paths = [
    data_nopunc_dir + "/" + name for name in sorted(os.listdir(data_nopunc_dir))
]

# Punctuation classes predicted by the model; '' stands for "no punctuation".
classes = [',', '.', '?', '!', '-', ':', '...', '']

# --- Model and training hyper-parameters -------------------------------------
nlp = spacy.load("pl_core_news_sm")
context_size = 5  # number of context words per sample
model = Model()
epochs = 5
output_prefix = "model"  # checkpoints are written as "<prefix>-<epoch>.pt"
train_loss_acc = 30  # period (in samples) of the optimizer/scheduler step

# Use the GPU when one is present instead of failing on CPU-only machines.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=15, eta_min=0.0001
)
loss_function = torch.nn.MSELoss()

model.train()

# TODO:
#   1) add batch processing
#   2) change the loss function

# Number of training samples processed so far.
counter = 0
# To resume from a saved checkpoint:
# model.load_state_dict(torch.load("model-4.pt"))
2021-06-16 19:04:43 +02:00
2021-06-16 13:03:28 +02:00
def _window_words(lines, start):
    """Return ``(context_words, lookahead_word)`` for the window at ``start``."""
    context_words = [line2word(line) for line in lines[start: start + context_size]]
    # NOTE(review): the look-ahead index skips the line directly after the
    # context (``+ 1``); kept as in the original code - confirm it is intended.
    lookahead_word = line2word(lines[start + context_size + 1])
    return context_words, lookahead_word


def _words_to_tensor(words):
    """Embed ``words`` and return them as a float tensor on ``device``."""
    # Stack into one ndarray first; torch.tensor() on a list of arrays is slow.
    vectors = np.array(words_to_vecs(words))
    return torch.tensor(vectors, dtype=torch.float).to(device)


if mode == "train":
    for epoch in range(epochs):
        # NOTE(review): only the first 20 files and the first 100 lines of
        # each file are used - presumably a debugging limit.
        for path in tqdm.tqdm(data_paths[:20]):
            with open(path, mode="r", encoding="utf-8") as file:
                lines = file.readlines()[:100]

            for i in range(len(lines) - context_size - 2):
                context_words, lookahead_word = _window_words(lines, i)
                mark = find_interpunction(context_words[-1], classes)
                # Samples whose last context word carries no punctuation are
                # skipped during training.
                if mark == '':
                    continue

                x = _words_to_tensor(context_words + [lookahead_word])
                output, _ = model.forward(x)
                output = output.squeeze(1)
                prediction = torch.mean(output, 0)
                class_vector = compute_class_vector([mark], classes).to(device)
                loss = loss_function(prediction, class_vector)

                if counter % 10 == 0:
                    print(prediction)
                    print(loss)
                    print(class_vector)

                # Gradients are accumulated over ``train_loss_acc`` samples;
                # they must not be zeroed per sample or every update would
                # only see the last sample's gradient.
                # NOTE(review): assumes the original intent was gradient
                # accumulation (indentation was lost in the source).
                loss.backward()
                counter += 1
                if counter % train_loss_acc == 0:
                    # The optimizer has to step before the LR scheduler.
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()

        print("Epoch: {}".format(epoch))
        torch.save(
            model.state_dict(),
            os.path.join("./", f"{output_prefix}-{epoch}.pt"),
        )
elif mode == "evaluate":
    correct = 0
    incorrect = 0

    # Checkpoint written after the final training epoch ("model-4.pt" with
    # the default settings). map_location lets it load on CPU-only hosts.
    checkpoint = f"{output_prefix}-{epochs - 1}.pt"
    model.load_state_dict(torch.load(checkpoint, map_location=device))
    model.eval()

    with torch.no_grad():  # inference only, no autograd bookkeeping needed
        for pathA, pathB in zip(data_no_punc_paths, data_paths):
            with open(pathA, mode="r", encoding='utf-8') as file, \
                    open(pathB, mode="r", encoding='utf-8') as file2:
                listA = file.readlines()[:-1]
                listB = file2.readlines()[:-1]

            # NOTE(review): assumes both files are aligned line by line;
            # only positions present in both files are evaluated.
            usable = min(len(listA), len(listB))
            for i in range(usable - context_size - 1):
                # Model input comes from the punctuation-free file ...
                context_words, lookahead_word = _window_words(listA, i)
                # ... while the expected label comes from the punctuated
                # file, at the last context position (as in training).
                mark_y = find_interpunction(
                    line2word(listB[i + context_size - 1]), classes
                )

                x = _words_to_tensor(context_words + [lookahead_word])
                output, _ = model.forward(x)
                output = output.cpu().detach().numpy()
                output = np.squeeze(np.mean(output, axis=0))
                result_index = np.argmax(output)
                print(output)

                if classes[result_index] == mark_y:
                    correct += 1
                else:
                    incorrect += 1

    total = correct + incorrect
    if total:
        accuracy = correct / total
        print(accuracy)
    else:
        print("No evaluation samples found; accuracy is undefined.")