working learing
This commit is contained in:
parent
2f7e3f3d97
commit
f3404fc347
2
.gitignore
vendored
2
.gitignore
vendored
@ -6,3 +6,5 @@
|
||||
*.o
|
||||
.DS_Store
|
||||
.token
|
||||
.vscode
|
||||
fa/*
|
96
main.py
Normal file
96
main.py
Normal file
@ -0,0 +1,96 @@
|
||||
import os
|
||||
from util import Model
|
||||
import spacy
|
||||
import torch
|
||||
import numpy as np
|
||||
import tqdm
|
||||
|
||||
def clean_string(str):
|
||||
str = str.replace('\n', '')
|
||||
return str
|
||||
|
||||
def extract_word(line):
|
||||
return line.split(" ")[1]
|
||||
|
||||
def line2word(line):
|
||||
word = extract_word(line)
|
||||
word = clean_string(word)
|
||||
return word
|
||||
|
||||
def find_interpunction(line, classes):
|
||||
result = [x for x in classes if x in line]
|
||||
if len(result) > 0:
|
||||
return result[0]
|
||||
else:
|
||||
return ['']
|
||||
|
||||
def words_to_vecs(list_of_words):
|
||||
return [nlp(x).vector for x in list_of_words]
|
||||
|
||||
def softXEnt(input, target):
|
||||
m = torch.nn.LogSoftmax(dim = 1)
|
||||
logprobs = m(input)
|
||||
return -(target * logprobs).sum() / input.shape[0]
|
||||
|
||||
def compute_class_vector(mark, classes):
|
||||
result = np.zeros(len(classes))
|
||||
for x in range(len(classes)):
|
||||
if classes[x] == mark[0]:
|
||||
result[x] == 1
|
||||
return torch.tensor(result, dtype=torch.long)
|
||||
|
||||
|
||||
|
||||
data_dir = "./fa/poleval_final_dataset/train"
|
||||
data_nopunc_dir = "./fa/poleval_final_dataset1/train"
|
||||
|
||||
data_paths = os.listdir(data_dir)
|
||||
data_paths = [data_dir + "/" + x for x in data_paths]
|
||||
|
||||
classes = [',', '.', '?', '!', '-', ':', '...']
|
||||
nlp = spacy.load("pl_core_news_sm")
|
||||
context_size = 5
|
||||
|
||||
model = Model()
|
||||
epochs = 5
|
||||
output_prefix = "model"
|
||||
hidden_state = torch.randn((2, 1, 300), requires_grad=True)
|
||||
cell_state = torch.randn((2, 1, 300), requires_grad=True)
|
||||
|
||||
model.train()
|
||||
optimizer = torch.optim.AdamW(model.parameters(), lr=0.02)
|
||||
loss_function = softXEnt
|
||||
|
||||
for epoch in range(epochs):
|
||||
for path in tqdm.tqdm(data_paths):
|
||||
with open(path, "r") as file:
|
||||
list = file.readlines()[:-1]
|
||||
for i in range(0, len(list) - context_size - 1):
|
||||
model.zero_grad()
|
||||
x = list[i: i + context_size]
|
||||
x = [line2word(y) for y in x]
|
||||
x_1 = [line2word(list[i + context_size + 1])]
|
||||
x = x + x_1
|
||||
x = words_to_vecs(x)
|
||||
mark = find_interpunction(x, classes)
|
||||
mark = words_to_vecs(mark)
|
||||
|
||||
x = torch.tensor(x, dtype=torch.float)
|
||||
mark = torch.tensor(mark, dtype=torch.float)
|
||||
|
||||
output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
|
||||
output = output.squeeze(1)
|
||||
loss = loss_function(output, compute_class_vector(mark, classes))
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
hidden_state = hidden_state.detach()
|
||||
cell_state = cell_state.detach()
|
||||
|
||||
"""
|
||||
vector -> (96,), np nadarray
|
||||
"""
|
||||
print("Epoch: {}".format(epoch))
|
||||
torch.save(
|
||||
model.state_dict(),
|
||||
os.path.join("./", f"{output_prefix}-{epoch}.pt"),
|
||||
)
|
29
util.py
Normal file
29
util.py
Normal file
@ -0,0 +1,29 @@
|
||||
import torch
|
||||
|
||||
class Model(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Model, self).__init__()
|
||||
# in: 4 słowa kontekstu przed i 1 słowo kontekstu po
|
||||
"""
|
||||
5 in features
|
||||
150 out features
|
||||
"""
|
||||
self.dense1 = torch.nn.Linear(6, 150, bias=False)
|
||||
self.tanh1 = torch.nn.Tanh()
|
||||
"""
|
||||
150 in features
|
||||
300 hidden values
|
||||
2 num layers
|
||||
"""
|
||||
self.lstm = torch.nn.LSTM(150, 300, 2)
|
||||
self.dense2 = torch.nn.Linear(300, 7)
|
||||
self.softmax = torch.nn.Softmax()
|
||||
|
||||
def forward(self, data, hidden_state, cell_state):
|
||||
data = self.dense1(data.T)
|
||||
data = self.tanh1(data)
|
||||
data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1), (hidden_state, cell_state))
|
||||
data = self.dense2(data)
|
||||
data = self.softmax(data)
|
||||
return data, (hidden_state, cell_state)
|
Loading…
Reference in New Issue
Block a user