wangobango 2021-06-20 22:03:34 +02:00
parent a2f39d1f56
commit 434e164ea3
1 changed file with 44 additions and 13 deletions

main.py

@@ -2,6 +2,7 @@ from os import sep
 from nltk import word_tokenize
 import pandas as pd
 import torch
+from torch._C import device
 from tqdm import tqdm
 from torchtext.vocab import vocab
 from collections import Counter, OrderedDict
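A note on the added import: torch._C is PyTorch's private C-extension module, so "from torch._C import device" relies on an internal path that the public API does not guarantee. The supported spelling is torch.device, as in this minimal sketch (illustrative code, not from the repository):

import torch

# Prefer the GPU when one is available; fall back to the CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.zeros(3, 3).to(device)  # .to(device) moves tensors or modules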
@@ -29,7 +30,15 @@ class Model(torch.nn.Module):
         # out = self.dense1(out.squeeze(0).T)
         out = self.hidden2tag(out)
         out = self.crf(out, tags.T)
-        out = self.sigm(self.fc1(torch.tensor([out])))
+        # out = self.sigm(self.fc1(torch.tensor([out])))
+        return out
+
+    def decode(self, data):
+        emb = self.relu(self.emb(data))
+        out, h_n = self.gru(emb)
+        # out = self.dense1(out.squeeze(0).T)
+        out = self.hidden2tag(out)
+        out = self.crf.decode(out)
         return out
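The forward pass scores a known tag sequence with self.crf(out, tags.T), while the new decode method asks the same layer for its best guess via self.crf.decode(out). This usage matches the pytorch-crf package (torchcrf.CRF); assuming that dependency, a self-contained sketch of the contract (all names here are illustrative):

import torch
from torchcrf import CRF  # pip install pytorch-crf (assumed dependency)

num_tags, seq_len, batch = 9, 15, 1
crf = CRF(num_tags)  # default tensor layout: (seq_len, batch, num_tags)

emissions = torch.randn(seq_len, batch, num_tags)  # stand-in for hidden2tag output
gold_tags = torch.randint(num_tags, (seq_len, batch))

log_likelihood = crf(emissions, gold_tags)  # scalar; larger means a better fit
loss = -log_likelihood                      # the usual training objective negates it
best_paths = crf.decode(emissions)          # one best tag-index list per batch element

Note that decode returns plain Python lists rather than tensors, so its output is ready for evaluation but cannot be backpropagated through.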
@@ -74,31 +83,53 @@ labels_vocab = {
 train_tokens_ids = data_process(data["text"])
 train_labels = labels_process(data["labels"])
 num_tags = 9
 NUM_EPOCHS = 5
 seq_length = 15
 model = Model(num_tags, seq_length)
+device = torch.device("cuda")
+model.to(device)
+model.cuda(0)
 criterion = torch.nn.CrossEntropyLoss()
 optimizer = torch.optim.Adam(model.parameters())
 train_dataloader = DataLoader(list(zip(train_tokens_ids, train_labels)), batch_size=64, shuffle=True)
 # test_dataloader = DataLoader(train_labels, batch_size=64, shuffle=True)
-for i in range(NUM_EPOCHS):
-    model.train()
-    #for i in tqdm(range(500)):
+mode = "train"
+# mode = "eval"
+# mode = "generate"
+
+if mode == "train":
+    for i in range(NUM_EPOCHS):
+        model.train()
+        #for i in tqdm(range(500)):
+        for i in tqdm(range(len(train_labels))):
+            for k in range(0, len(train_tokens_ids[i]) - seq_length, seq_length):
+                batch_tokens = train_tokens_ids[i][k: k + seq_length].unsqueeze(0)
+                tags = train_labels[i][k: k + seq_length].unsqueeze(1)
+                predicted_tags = model(batch_tokens.to(device), tags.to(device))
+                optimizer.zero_grad()
+                # tags = torch.tensor([x[0] for x in tags])
+                # loss = criterion(predicted_tags.unsqueeze(0),tags.T)
+                predicted_tags.backward()
+                optimizer.step()
+                model.zero_grad()
+    torch.save(model.state_dict(), "model.torch")
+
+if mode == "eval":
+    model.eval()
     for i in tqdm(range(len(train_labels))):
         for k in range(0, len(train_tokens_ids[i]) - seq_length, seq_length):
             batch_tokens = train_tokens_ids[i][k: k + seq_length].unsqueeze(0)
             tags = train_labels[i][k: k + seq_length].unsqueeze(1)
-            predicted_tags = model(batch_tokens, tags)
-            print('dupa')
-            optimizer.zero_grad()
-            tags = torch.tensor([x[0] for x in tags])
-            loss = criterion(predicted_tags.unsqueeze(0),tags.T)
-            loss.backward()
-            optimizer.step()
-            model.zero_grad()
+            predicted_tags = model.decode(batch_tokens.to(device))
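Two details of the new branches are worth flagging. model.cuda(0) directly after model.to(device) is redundant, since either call alone moves the parameters. More importantly, predicted_tags.backward() differentiates the raw CRF score that the forward pass returns; if that score is a log-likelihood (as with pytorch-crf), descending its gradient makes the training data less likely. A hedged sketch of the conventional variant, reusing the diff's own names and assuming that CRF semantics:

if mode == "train":
    for epoch in range(NUM_EPOCHS):
        model.train()
        for i in tqdm(range(len(train_labels))):
            for k in range(0, len(train_tokens_ids[i]) - seq_length, seq_length):
                batch_tokens = train_tokens_ids[i][k: k + seq_length].unsqueeze(0)
                tags = train_labels[i][k: k + seq_length].unsqueeze(1)
                optimizer.zero_grad()
                # negate the log-likelihood so minimizing the loss maximizes it
                loss = -model(batch_tokens.to(device), tags.to(device))
                loss.backward()
                optimizer.step()
    torch.save(model.state_dict(), "model.torch")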