diff --git a/.gitignore b/.gitignore index 1c18d74..4318506 100644 --- a/.gitignore +++ b/.gitignore @@ -6,3 +6,5 @@ *.o .DS_Store .token +.vscode +fa/* \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000..80086d4 --- /dev/null +++ b/main.py @@ -0,0 +1,96 @@ +import os +from util import Model +import spacy +import torch +import numpy as np +import tqdm + +def clean_string(str): + str = str.replace('\n', '') + return str + +def extract_word(line): + return line.split(" ")[1] + +def line2word(line): + word = extract_word(line) + word = clean_string(word) + return word + +def find_interpunction(line, classes): + result = [x for x in classes if x in line] + if len(result) > 0: + return result[0] + else: + return [''] + +def words_to_vecs(list_of_words): + return [nlp(x).vector for x in list_of_words] + +def softXEnt(input, target): + m = torch.nn.LogSoftmax(dim = 1) + logprobs = m(input) + return -(target * logprobs).sum() / input.shape[0] + +def compute_class_vector(mark, classes): + result = np.zeros(len(classes)) + for x in range(len(classes)): + if classes[x] == mark[0]: + result[x] == 1 + return torch.tensor(result, dtype=torch.long) + + + +data_dir = "./fa/poleval_final_dataset/train" +data_nopunc_dir = "./fa/poleval_final_dataset1/train" + +data_paths = os.listdir(data_dir) +data_paths = [data_dir + "/" + x for x in data_paths] + +classes = [',', '.', '?', '!', '-', ':', '...'] +nlp = spacy.load("pl_core_news_sm") +context_size = 5 + +model = Model() +epochs = 5 +output_prefix = "model" +hidden_state = torch.randn((2, 1, 300), requires_grad=True) +cell_state = torch.randn((2, 1, 300), requires_grad=True) + +model.train() +optimizer = torch.optim.AdamW(model.parameters(), lr=0.02) +loss_function = softXEnt + +for epoch in range(epochs): + for path in tqdm.tqdm(data_paths): + with open(path, "r") as file: + list = file.readlines()[:-1] + for i in range(0, len(list) - context_size - 1): + model.zero_grad() + x = list[i: i + context_size] + x = [line2word(y) for y in x] + x_1 = [line2word(list[i + context_size + 1])] + x = x + x_1 + x = words_to_vecs(x) + mark = find_interpunction(x, classes) + mark = words_to_vecs(mark) + + x = torch.tensor(x, dtype=torch.float) + mark = torch.tensor(mark, dtype=torch.float) + + output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state) + output = output.squeeze(1) + loss = loss_function(output, compute_class_vector(mark, classes)) + loss.backward() + optimizer.step() + hidden_state = hidden_state.detach() + cell_state = cell_state.detach() + + """ + vector -> (96,), np nadarray + """ + print("Epoch: {}".format(epoch)) + torch.save( + model.state_dict(), + os.path.join("./", f"{output_prefix}-{epoch}.pt"), + ) \ No newline at end of file diff --git a/util.py b/util.py new file mode 100644 index 0000000..6199b86 --- /dev/null +++ b/util.py @@ -0,0 +1,29 @@ +import torch + +class Model(torch.nn.Module): + + def __init__(self): + super(Model, self).__init__() + # in: 4 słowa kontekstu przed i 1 słowo kontekstu po + """ + 5 in features + 150 out features + """ + self.dense1 = torch.nn.Linear(6, 150, bias=False) + self.tanh1 = torch.nn.Tanh() + """ + 150 in features + 300 hidden values + 2 num layers + """ + self.lstm = torch.nn.LSTM(150, 300, 2) + self.dense2 = torch.nn.Linear(300, 7) + self.softmax = torch.nn.Softmax() + + def forward(self, data, hidden_state, cell_state): + data = self.dense1(data.T) + data = self.tanh1(data) + data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1), (hidden_state, cell_state)) + data = self.dense2(data) + data = self.softmax(data) + return data, (hidden_state, cell_state) \ No newline at end of file