punctuation_restoration/main.py

96 lines
2.8 KiB
Python
Raw Normal View History

2021-06-16 12:51:01 +02:00
import os
from util import Model
import spacy
import torch
import numpy as np
import tqdm
def clean_string(text):
    """Return *text* with all newline characters removed.

    Renamed the parameter from ``str`` — the original shadowed the
    builtin ``str`` type. Behavior is unchanged for the positional
    call sites in this file.
    """
    return text.replace('\n', '')
def extract_word(line):
    """Return the second space-delimited token of *line*.

    Dataset lines appear to be "<something> <word> ..."; this picks the
    word field. Raises IndexError if the line has no space (unchanged
    from the original).
    """
    tokens = line.split(" ")
    return tokens[1]
def line2word(line):
    """Extract the word field from a dataset line and strip newlines."""
    return clean_string(extract_word(line))
def find_interpunction(line, classes):
    """Return the first mark from *classes* contained in *line*, as a list.

    Always returns a one-element list: ``[mark]`` when a mark is found,
    ``['']`` otherwise. The original returned a bare string in the found
    case but a list in the not-found case; since the caller iterates the
    result (``words_to_vecs(mark)``), a bare multi-character mark such as
    ``'...'`` would have been split into individual characters.
    """
    matches = [mark for mark in classes if mark in line]
    if matches:
        return matches[:1]
    return ['']
def words_to_vecs(list_of_words):
    """Map each word to its embedding vector via the module-level ``nlp``."""
    vectors = []
    for word in list_of_words:
        vectors.append(nlp(word).vector)
    return vectors
def softXEnt(input, target):
    """Cross entropy with soft (probability-distribution) targets.

    Computes -sum(target * log_softmax(input)) averaged over the batch
    (first dimension of *input*).
    """
    log_probs = torch.nn.functional.log_softmax(input, dim=1)
    batch_size = input.shape[0]
    return -(target * log_probs).sum() / batch_size
def compute_class_vector(mark, classes):
    """One-hot encode ``mark[0]`` against *classes*.

    Returns a ``torch.long`` tensor of length ``len(classes)`` with a 1
    at the index of the matching class and 0 elsewhere (all zeros if
    ``mark[0]`` is not in *classes*).

    Bug fix: the original used ``result[x] == 1`` — a no-op comparison
    instead of an assignment — so it always returned a zero vector.
    """
    result = np.zeros(len(classes))
    for idx, cls in enumerate(classes):
        if cls == mark[0]:
            result[idx] = 1  # was '==' in the original: comparison, not assignment
    return torch.tensor(result, dtype=torch.long)
# --- Configuration and model setup (module-level side effects; order matters) ---
data_dir = "./fa/poleval_final_dataset/train"
data_nopunc_dir = "./fa/poleval_final_dataset1/train"  # NOTE(review): never used below — confirm it is needed
data_paths = os.listdir(data_dir)
# Prepend the directory so each entry is a readable path.
data_paths = [data_dir + "/" + x for x in data_paths]
# Punctuation marks the model is trained to predict.
classes = [',', '.', '?', '!', '-', ':', '...']
# Polish spaCy pipeline; provides the word vectors used by words_to_vecs.
nlp = spacy.load("pl_core_news_sm")
context_size = 5  # number of context words fed to the model per sample
model = Model()
epochs = 5
output_prefix = "model"  # checkpoint filename prefix for torch.save below
# Initial recurrent states, shape (2, 1, 300) — presumably
# (num_layers*num_directions, batch, hidden); TODO confirm against Model.
hidden_state = torch.randn((2, 1, 300), requires_grad=True)
cell_state = torch.randn((2, 1, 300), requires_grad=True)
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.02)
loss_function = softXEnt
# --- Training loop: one pass per epoch over every file in data_paths ---
for epoch in range(epochs):
    for path in tqdm.tqdm(data_paths):
        # Read one training file, dropping its final line.
        with open(path, "r") as file:
            lines = file.readlines()[:-1]  # renamed from 'list' (shadowed the builtin)
        # Slide a window of context_size words plus one look-ahead word.
        for i in range(0, len(lines) - context_size - 1):
            model.zero_grad()
            window = [line2word(line) for line in lines[i: i + context_size]]
            # One extra word located one past the end of the window.
            window.append(line2word(lines[i + context_size + 1]))
            # Word vectors — presumably (96,) numpy arrays from
            # pl_core_news_sm; TODO confirm.
            x = words_to_vecs(window)
            # NOTE(review): find_interpunction is called on the list of word
            # vectors, not on the raw text lines — verify this is intended.
            mark = find_interpunction(x, classes)
            mark = words_to_vecs(mark)
            x = torch.tensor(x, dtype=torch.float)
            mark = torch.tensor(mark, dtype=torch.float)
            # Call the module itself rather than .forward() so hooks run.
            output, (hidden_state, cell_state) = model(x, hidden_state, cell_state)
            output = output.squeeze(1)
            loss = loss_function(output, compute_class_vector(mark, classes))
            loss.backward()
            optimizer.step()
            # Detach the recurrent state so gradients do not flow across
            # window steps (truncated backpropagation through time).
            hidden_state = hidden_state.detach()
            cell_state = cell_state.detach()
    print("Epoch: {}".format(epoch))
    # Checkpoint after every epoch: ./model-<epoch>.pt
    torch.save(
        model.state_dict(),
        os.path.join("./", f"{output_prefix}-{epoch}.pt"),
    )