This commit is contained in:
wangobango 2021-06-17 17:54:10 +02:00
parent d013ea535d
commit a894711f00
2 changed files with 46 additions and 27 deletions

67
main.py
View File

@ -1,4 +1,5 @@
import os import os
from typing import Counter
from util import Model from util import Model
import spacy import spacy
import torch import torch
@ -38,7 +39,7 @@ def compute_class_vector(mark, classes):
for x in range(len(classes)): for x in range(len(classes)):
if classes[x] == mark[0]: if classes[x] == mark[0]:
result[x] = 1.0 result[x] = 1.0
return torch.tensor(result, dtype=torch.long) return torch.tensor(result, dtype=torch.float)
def prepare_input(index, data, context_size): def prepare_input(index, data, context_size):
x = data[index: index + context_size] x = data[index: index + context_size]
@ -76,28 +77,32 @@ context_size = 5
model = Model() model = Model()
epochs = 5 epochs = 5
output_prefix = "model" output_prefix = "model"
train_loss_acc = 30
device=torch.device("cuda") device=torch.device("cuda")
model = model.cuda() model = model.cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 15, 0.0001)
loss_function = torch.nn.MSELoss()
model.train() model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.02)
loss_function = softXEnt
""" """
TODO TODO
1) dodać przetwarzanie baczowe 1) dodać przetwarzanie baczowe
2) dodać osobną sieć w pełni połączoną która używa dźwięku żeby wykrywać czy użyć interpunkcji czy nie 2) zmienić loss function
""" """
hidden_state = torch.randn((2, 1, 300), requires_grad=True).to(device) # hidden_state = torch.randn((2, 1, 300), requires_grad=True).to(device)
cell_state = torch.randn((2, 1, 300), requires_grad=True).to(device) # cell_state = torch.randn((2, 1, 300), requires_grad=True).to(device)
counter = 0
# model.load_state_dict(torch.load("model-4.pt"))
if mode == "train": if mode == "train":
for epoch in range(epochs): for epoch in range(epochs):
for path in tqdm.tqdm(data_paths): for path in tqdm.tqdm(data_paths[:20]):
with open(path, mode="r", encoding="utf-8") as file: with open(path, mode="r", encoding="utf-8") as file:
list = file.readlines() list = file.readlines()[:100]
for i in range(0, len(list) - context_size - 1 - 1): for i in range(0, len(list) - context_size - 1 - 1):
model.zero_grad() model.zero_grad()
x = list[i: i + context_size] x = list[i: i + context_size]
@ -105,40 +110,54 @@ if mode == "train":
x_1 = [line2word(list[i + context_size + 1])] x_1 = [line2word(list[i + context_size + 1])]
mark = find_interpunction(x[-1], classes) mark = find_interpunction(x[-1], classes)
if mark == '':
continue
x = x + x_1 x = x + x_1
x = words_to_vecs(x) x = words_to_vecs(x)
x = torch.tensor(x, dtype=torch.float) x = torch.tensor(x, dtype=torch.float)
x = x.to(device) x = x.to(device)
output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state) output, (_,_) = model.forward(x)
output = output.squeeze(1) output = output.squeeze(1)
loss = loss_function(output, compute_class_vector([mark], classes).to(device)) class_vector = compute_class_vector([mark], classes).to(device)
loss = loss_function(torch.mean(output, 0), class_vector)
if counter % 10 == 0:
print(torch.mean(output, 0))
print(loss)
print(class_vector)
loss.backward() loss.backward()
optimizer.step() if counter % train_loss_acc == 0:
hidden_state = hidden_state.detach() scheduler.step()
cell_state = cell_state.detach() optimizer.step()
optimizer.zero_grad()
model.zero_grad()
counter += 1
print("Epoch: {}".format(epoch)) print("Epoch: {}".format(epoch))
torch.save( torch.save(
model.state_dict(), model.state_dict(),
os.path.join("./", f"{output_prefix}-{epoch}.pt"), os.path.join("./", f"{output_prefix}-{epoch}.pt"),
) )
with open("hidden_state.pickle", "wb") as hs: # with open("hidden_state.pickle", "wb") as hs:
pickle.dump(hidden_state, hs) # pickle.dump(hidden_state, hs)
with open("cell_state.pickle", "wb") as cs: # with open("cell_state.pickle", "wb") as cs:
pickle.dump(cell_state, cs) # pickle.dump(cell_state, cs)
elif mode == "evaluate": elif mode == "evaluate":
correct = 0 correct = 0
incorrect = 0 incorrect = 0
threshold = 0.3 threshold = 0.3
model.load_state_dict(torch.load("model-0.pt")) model.load_state_dict(torch.load("model-4.pt"))
model.eval() model.eval()
with open("hidden_state.pickle", "rb") as hs: # with open("hidden_state.pickle", "rb") as hs:
hidden_state = pickle.load(hs) # hidden_state = pickle.load(hs)
with open("cell_state.pickle", "rb") as cs: # with open("cell_state.pickle", "rb") as cs:
cell_state = pickle.load(cs) # cell_state = pickle.load(cs)
for pathA, pathB in zip(data_no_punc_paths, data_paths): for pathA, pathB in zip(data_no_punc_paths, data_paths):
with open(pathA, mode="r", encoding='utf-8') as file: with open(pathA, mode="r", encoding='utf-8') as file:
with open(pathB, mode="r", encoding='utf-8') as file2: with open(pathB, mode="r", encoding='utf-8') as file2:
@ -156,7 +175,7 @@ elif mode == "evaluate":
mark_y = find_interpunction(x[-1], classes) mark_y = find_interpunction(x[-1], classes)
x = torch.tensor(x, dtype=torch.float).to(device) x = torch.tensor(x, dtype=torch.float).to(device)
output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state) output, (hidden_state, cell_state) = model.forward(x)
output = output.cpu() output = output.cpu()
output = output.detach().numpy() output = output.detach().numpy()
output = np.mean(output, axis=0) output = np.mean(output, axis=0)
@ -164,7 +183,7 @@ elif mode == "evaluate":
result_index = np.argmax(output) result_index = np.argmax(output)
# if output[result_index] < threshold: # if output[result_index] < threshold:
# incorrect += 1 # incorrect += 1
print(output)
if len(mark_y) > 0: if len(mark_y) > 0:
if classes[np.argmax(output)] == mark_y: if classes[np.argmax(output)] == mark_y:
correct += 1 correct += 1

View File

@ -20,11 +20,11 @@ class Model(torch.nn.Module):
self.dense2 = torch.nn.Linear(300, 8) self.dense2 = torch.nn.Linear(300, 8)
self.softmax = torch.nn.Softmax(dim=0) self.softmax = torch.nn.Softmax(dim=0)
def forward(self, data, hidden_state, cell_state): def forward(self, data):
data = self.dense1(data.T) data = self.dense1(data.T)
data = self.tanh1(data) data = self.tanh1(data)
data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1), (hidden_state, cell_state)) # data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1), (hidden_state, cell_state))
# data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1)) data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1))
data = self.dense2(data) data = self.dense2(data)
data = self.softmax(data) data = self.softmax(data)
return data, (hidden_state, cell_state) return data, (hidden_state, cell_state)