This commit is contained in:
wangobango 2021-06-17 17:54:10 +02:00
parent d013ea535d
commit a894711f00
2 changed files with 46 additions and 27 deletions

67
main.py
View File

@ -1,4 +1,5 @@
import os
from typing import Counter
from util import Model
import spacy
import torch
@ -38,7 +39,7 @@ def compute_class_vector(mark, classes):
for x in range(len(classes)):
if classes[x] == mark[0]:
result[x] = 1.0
return torch.tensor(result, dtype=torch.long)
return torch.tensor(result, dtype=torch.float)
def prepare_input(index, data, context_size):
x = data[index: index + context_size]
@ -76,28 +77,32 @@ context_size = 5
model = Model()
epochs = 5
output_prefix = "model"
train_loss_acc = 30
device=torch.device("cuda")
model = model.cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 15, 0.0001)
loss_function = torch.nn.MSELoss()
model.train()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.02)
loss_function = softXEnt
"""
TODO
1) dodać przetwarzanie baczowe
2) dodać osobną sieć w pełni połączoną która używa dźwięku żeby wykrywać czy użyć interpunkcji czy nie
2) zmienić loss function
"""
hidden_state = torch.randn((2, 1, 300), requires_grad=True).to(device)
cell_state = torch.randn((2, 1, 300), requires_grad=True).to(device)
# hidden_state = torch.randn((2, 1, 300), requires_grad=True).to(device)
# cell_state = torch.randn((2, 1, 300), requires_grad=True).to(device)
counter = 0
# model.load_state_dict(torch.load("model-4.pt"))
if mode == "train":
for epoch in range(epochs):
for path in tqdm.tqdm(data_paths):
for path in tqdm.tqdm(data_paths[:20]):
with open(path, mode="r", encoding="utf-8") as file:
list = file.readlines()
list = file.readlines()[:100]
for i in range(0, len(list) - context_size - 1 - 1):
model.zero_grad()
x = list[i: i + context_size]
@ -105,40 +110,54 @@ if mode == "train":
x_1 = [line2word(list[i + context_size + 1])]
mark = find_interpunction(x[-1], classes)
if mark == '':
continue
x = x + x_1
x = words_to_vecs(x)
x = torch.tensor(x, dtype=torch.float)
x = x.to(device)
output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
output, (_,_) = model.forward(x)
output = output.squeeze(1)
loss = loss_function(output, compute_class_vector([mark], classes).to(device))
class_vector = compute_class_vector([mark], classes).to(device)
loss = loss_function(torch.mean(output, 0), class_vector)
if counter % 10 == 0:
print(torch.mean(output, 0))
print(loss)
print(class_vector)
loss.backward()
optimizer.step()
hidden_state = hidden_state.detach()
cell_state = cell_state.detach()
if counter % train_loss_acc == 0:
scheduler.step()
optimizer.step()
optimizer.zero_grad()
model.zero_grad()
counter += 1
print("Epoch: {}".format(epoch))
torch.save(
model.state_dict(),
os.path.join("./", f"{output_prefix}-{epoch}.pt"),
)
with open("hidden_state.pickle", "wb") as hs:
pickle.dump(hidden_state, hs)
with open("cell_state.pickle", "wb") as cs:
pickle.dump(cell_state, cs)
# with open("hidden_state.pickle", "wb") as hs:
# pickle.dump(hidden_state, hs)
# with open("cell_state.pickle", "wb") as cs:
# pickle.dump(cell_state, cs)
elif mode == "evaluate":
correct = 0
incorrect = 0
threshold = 0.3
model.load_state_dict(torch.load("model-0.pt"))
model.load_state_dict(torch.load("model-4.pt"))
model.eval()
with open("hidden_state.pickle", "rb") as hs:
hidden_state = pickle.load(hs)
with open("cell_state.pickle", "rb") as cs:
cell_state = pickle.load(cs)
# with open("hidden_state.pickle", "rb") as hs:
# hidden_state = pickle.load(hs)
# with open("cell_state.pickle", "rb") as cs:
# cell_state = pickle.load(cs)
for pathA, pathB in zip(data_no_punc_paths, data_paths):
with open(pathA, mode="r", encoding='utf-8') as file:
with open(pathB, mode="r", encoding='utf-8') as file2:
@ -156,7 +175,7 @@ elif mode == "evaluate":
mark_y = find_interpunction(x[-1], classes)
x = torch.tensor(x, dtype=torch.float).to(device)
output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
output, (hidden_state, cell_state) = model.forward(x)
output = output.cpu()
output = output.detach().numpy()
output = np.mean(output, axis=0)
@ -164,7 +183,7 @@ elif mode == "evaluate":
result_index = np.argmax(output)
# if output[result_index] < threshold:
# incorrect += 1
print(output)
if len(mark_y) > 0:
if classes[np.argmax(output)] == mark_y:
correct += 1

View File

@ -20,11 +20,11 @@ class Model(torch.nn.Module):
self.dense2 = torch.nn.Linear(300, 8)
self.softmax = torch.nn.Softmax(dim=0)
def forward(self, data, hidden_state, cell_state):
def forward(self, data):
data = self.dense1(data.T)
data = self.tanh1(data)
data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1), (hidden_state, cell_state))
# data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1))
# data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1), (hidden_state, cell_state))
data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1))
data = self.dense2(data)
data = self.softmax(data)
return data, (hidden_state, cell_state)