dupa
commit a894711f00
parent d013ea535d
main.py (67 lines changed)
@@ -1,4 +1,5 @@
 import os
+from typing import Counter
 from util import Model
 import spacy
 import torch
@@ -38,7 +39,7 @@ def compute_class_vector(mark, classes):
     for x in range(len(classes)):
         if classes[x] == mark[0]:
             result[x] = 1.0
-    return torch.tensor(result, dtype=torch.long)
+    return torch.tensor(result, dtype=torch.float)
 
 def prepare_input(index, data, context_size):
     x = data[index: index + context_size]
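The dtype switch above lines up with the loss change later in this commit: hard class indices (`torch.long`) are what `torch.nn.CrossEntropyLoss` expects, while a one-hot float vector is what `MSELoss` or a soft cross-entropy such as the `softXEnt` referenced below can consume. A minimal illustration, with a hypothetical class list:

```python
import torch

classes = ['.', ',', '?']  # hypothetical class list, for illustration only

# Hard target: the index of the true class, as torch.nn.CrossEntropyLoss expects.
hard_target = torch.tensor(1, dtype=torch.long)

# Soft/one-hot target: a float distribution over classes, as MSELoss or a
# soft cross-entropy expects -- hence the long -> float change in this hunk.
soft_target = torch.tensor([0.0, 1.0, 0.0], dtype=torch.float)
```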
@@ -76,28 +77,32 @@ context_size = 5
 model = Model()
 epochs = 5
 output_prefix = "model"
+train_loss_acc = 30
 
 device = torch.device("cuda")
 model = model.cuda()
-optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)
 scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 15, 0.0001)
-loss_function = torch.nn.MSELoss()
 model.train()
+optimizer = torch.optim.AdamW(model.parameters(), lr=0.02)
+loss_function = softXEnt
+
 """
 TODO
 1) add batch processing
+2) add a separate fully connected network that uses the audio to detect whether punctuation should be inserted or not
 3) change the loss function
 """
 
 
-hidden_state = torch.randn((2, 1, 300), requires_grad=True).to(device)
-cell_state = torch.randn((2, 1, 300), requires_grad=True).to(device)
+# hidden_state = torch.randn((2, 1, 300), requires_grad=True).to(device)
+# cell_state = torch.randn((2, 1, 300), requires_grad=True).to(device)
 counter = 0
+# model.load_state_dict(torch.load("model-4.pt"))
+
 if mode == "train":
     for epoch in range(epochs):
-        for path in tqdm.tqdm(data_paths):
+        for path in tqdm.tqdm(data_paths[:20]):
             with open(path, mode="r", encoding="utf-8") as file:
-                list = file.readlines()
+                list = file.readlines()[:100]
                 for i in range(0, len(list) - context_size - 1 - 1):
-                    model.zero_grad()
                     x = list[i: i + context_size]
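`softXEnt` is assigned as the loss function above, but its definition is not part of this diff. A common soft-label cross-entropy looks like the sketch below; this is an assumption about its shape, not the repository's actual code. Note that a loss written this way expects raw logits, whereas `Model.forward` in util.py already applies a `Softmax`:

```python
import torch
import torch.nn.functional as F

def softXEnt(logits, target):
    # Cross-entropy between a (possibly soft) float target distribution and
    # the predicted distribution; reduces to standard CE for one-hot targets.
    logprobs = F.log_softmax(logits, dim=-1)
    return -(target * logprobs).sum(dim=-1).mean()
```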
@@ -105,40 +110,54 @@ if mode == "train":
                     x_1 = [line2word(list[i + context_size + 1])]
                     mark = find_interpunction(x[-1], classes)
 
                     if mark == '':
                         continue
 
                     x = x + x_1
                     x = words_to_vecs(x)
 
                     x = torch.tensor(x, dtype=torch.float)
 
                     x = x.to(device)
-                    output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
+                    output, (_, _) = model.forward(x)
                     output = output.squeeze(1)
-                    loss = loss_function(output, compute_class_vector([mark], classes).to(device))
+                    class_vector = compute_class_vector([mark], classes).to(device)
+                    loss = loss_function(torch.mean(output, 0), class_vector)
+
+                    if counter % 10 == 0:
+                        print(torch.mean(output, 0))
+                        print(loss)
+                        print(class_vector)
 
                     loss.backward()
-                    optimizer.step()
-                    hidden_state = hidden_state.detach()
-                    cell_state = cell_state.detach()
+                    if counter % train_loss_acc == 0:
+                        scheduler.step()
+                        optimizer.step()
+                        optimizer.zero_grad()
+                        model.zero_grad()
+
                     counter += 1
 
         print("Epoch: {}".format(epoch))
         torch.save(
             model.state_dict(),
             os.path.join("./", f"{output_prefix}-{epoch}.pt"),
         )
-        with open("hidden_state.pickle", "wb") as hs:
-            pickle.dump(hidden_state, hs)
-        with open("cell_state.pickle", "wb") as cs:
-            pickle.dump(cell_state, cs)
+        # with open("hidden_state.pickle", "wb") as hs:
+        #     pickle.dump(hidden_state, hs)
+        # with open("cell_state.pickle", "wb") as cs:
+        #     pickle.dump(cell_state, cs)
 
 elif mode == "evaluate":
     correct = 0
     incorrect = 0
     threshold = 0.3
-    model.load_state_dict(torch.load("model-0.pt"))
+    model.load_state_dict(torch.load("model-4.pt"))
     model.eval()
-    with open("hidden_state.pickle", "rb") as hs:
-        hidden_state = pickle.load(hs)
-    with open("cell_state.pickle", "rb") as cs:
-        cell_state = pickle.load(cs)
+    # with open("hidden_state.pickle", "rb") as hs:
+    #     hidden_state = pickle.load(hs)
+    # with open("cell_state.pickle", "rb") as cs:
+    #     cell_state = pickle.load(cs)
     for pathA, pathB in zip(data_no_punc_paths, data_paths):
         with open(pathA, mode="r", encoding='utf-8') as file:
             with open(pathB, mode="r", encoding='utf-8') as file2:
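The new `if counter % train_loss_acc == 0:` block is a gradient-accumulation pattern: `loss.backward()` runs on every sample, but the optimizer and LR scheduler only advance every `train_loss_acc` iterations. A self-contained toy sketch of the same pattern (the linear model and random data here are stand-ins, not the repository's code):

```python
import torch

# Toy setup mirroring the hyperparameters in the diff above.
model = torch.nn.Linear(10, 8)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.02)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, 15, 0.0001)
loss_function = torch.nn.MSELoss()
train_loss_acc = 30

for counter in range(300):
    x, y = torch.randn(10), torch.randn(8)   # stand-in data
    loss = loss_function(model(x), y)
    loss.backward()                           # gradients keep accumulating
    if counter % train_loss_acc == 0:         # every 30th iteration...
        scheduler.step()                      # ...advance the LR schedule,
        optimizer.step()                      # apply the accumulated gradients,
        optimizer.zero_grad()                 # and reset for the next window
```

One side effect worth noting: with `% == 0`, the very first iteration (counter 0) also triggers a step, after only a single backward pass.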
@@ -156,7 +175,7 @@ elif mode == "evaluate":
                 mark_y = find_interpunction(x[-1], classes)
                 x = torch.tensor(x, dtype=torch.float).to(device)
 
-                output, (hidden_state, cell_state) = model.forward(x, hidden_state, cell_state)
+                output, (hidden_state, cell_state) = model.forward(x)
                 output = output.cpu()
                 output = output.detach().numpy()
                 output = np.mean(output, axis=0)
@@ -164,7 +183,7 @@ elif mode == "evaluate":
                 result_index = np.argmax(output)
                 # if output[result_index] < threshold:
                 #     incorrect += 1
-
+                print(output)
                 if len(mark_y) > 0:
                     if classes[np.argmax(output)] == mark_y:
                         correct += 1
util.py (6 lines changed)
@@ -20,11 +20,11 @@ class Model(torch.nn.Module):
         self.dense2 = torch.nn.Linear(300, 8)
         self.softmax = torch.nn.Softmax(dim=0)
 
-    def forward(self, data, hidden_state, cell_state):
+    def forward(self, data):
         data = self.dense1(data.T)
         data = self.tanh1(data)
-        data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1), (hidden_state, cell_state))
-        # data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1))
+        # data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1), (hidden_state, cell_state))
+        data, (hidden_state, cell_state) = self.lstm(data.unsqueeze(1))
         data = self.dense2(data)
         data = self.softmax(data)
         return data, (hidden_state, cell_state)
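With the explicit `(hidden_state, cell_state)` argument gone, `torch.nn.LSTM` falls back to zero-initialized states on every call, which is why main.py no longer needs to create, detach, or pickle them. A quick self-contained check, with sizes matching the `(2, 1, 300)` state tensors main.py used to build (the LSTM input size of 300 is an assumption, since the layer's constructor is not shown in this diff):

```python
import torch

# 2 layers, hidden size 300 -- matching the (2, 1, 300) states from main.py.
lstm = torch.nn.LSTM(input_size=300, hidden_size=300, num_layers=2)
x = torch.randn(5, 1, 300)       # (seq_len, batch=1, features)

out, (h_n, c_n) = lstm(x)        # h_0/c_0 omitted -> they default to zeros
print(h_n.shape, c_n.shape)      # torch.Size([2, 1, 300]) for both
```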