From 434e164ea330f3a950787d7ffa7d48a9e6ee03fe Mon Sep 17 00:00:00 2001
From: wangobango
Date: Sun, 20 Jun 2021 22:03:34 +0200
Subject: [PATCH] progress

---
 main.py | 57 ++++++++++++++++++++++++++++++++++++++++++++-------------
 1 file changed, 44 insertions(+), 13 deletions(-)

diff --git a/main.py b/main.py
index 19bd6e2..547e290 100644
--- a/main.py
+++ b/main.py
@@ -2,6 +2,7 @@ from os import sep
 from nltk import word_tokenize
 import pandas as pd
 import torch
+from torch._C import device
 from tqdm import tqdm
 from torchtext.vocab import vocab
 from collections import Counter, OrderedDict
@@ -29,7 +30,15 @@ class Model(torch.nn.Module):
         # out = self.dense1(out.squeeze(0).T)
         out = self.hidden2tag(out)
         out = self.crf(out, tags.T)
-        out = self.sigm(self.fc1(torch.tensor([out])))
+        # out = self.sigm(self.fc1(torch.tensor([out])))
+        return out
+
+    def decode(self, data):
+        emb = self.relu(self.emb(data))
+        out, h_n = self.gru(emb)
+        # out = self.dense1(out.squeeze(0).T)
+        out = self.hidden2tag(out)
+        out = self.crf.decode(out)
         return out
 
 
@@ -74,31 +83,53 @@ labels_vocab = {
 train_tokens_ids = data_process(data["text"])
 train_labels = labels_process(data["labels"])
 
+num_tags = 9
 NUM_EPOCHS = 5
 seq_length = 15
 
 model = Model(num_tags, seq_length)
+device = torch.device("cuda")
+model.to(device)
+model.cuda(0)
+
 
 criterion = torch.nn.CrossEntropyLoss()
 optimizer = torch.optim.Adam(model.parameters())
 
 train_dataloader = DataLoader(list(zip(train_tokens_ids, train_labels)), batch_size=64, shuffle=True)
 # test_dataloader = DataLoader(train_labels, batch_size=64, shuffle=True)
 
-for i in range(NUM_EPOCHS):
-    model.train()
-    #for i in tqdm(range(500)):
+mode = "train"
+# mode = "eval"
+# mode = "generate"
+
+if mode == "train":
+    for i in range(NUM_EPOCHS):
+        model.train()
+        #for i in tqdm(range(500)):
+        for i in tqdm(range(len(train_labels))):
+            for k in range(0, len(train_tokens_ids[i]) - seq_length, seq_length):
+                batch_tokens = train_tokens_ids[i][k: k + seq_length].unsqueeze(0)
+                tags = train_labels[i][k: k + seq_length].unsqueeze(1)
+
+                predicted_tags = model(batch_tokens.to(device), tags.to(device))
+
+                optimizer.zero_grad()
+                # tags = torch.tensor([x[0] for x in tags])
+                # loss = criterion(predicted_tags.unsqueeze(0),tags.T)
+
+                predicted_tags.backward()
+                optimizer.step()
+                model.zero_grad()
+
+    torch.save(model.state_dict(), "model.torch")
+
+if mode == "eval":
+    model.eval()
     for i in tqdm(range(len(train_labels))):
         for k in range(0, len(train_tokens_ids[i]) - seq_length, seq_length):
             batch_tokens = train_tokens_ids[i][k: k + seq_length].unsqueeze(0)
             tags = train_labels[i][k: k + seq_length].unsqueeze(1)
 
-            predicted_tags = model(batch_tokens, tags)
-
-            optimizer.zero_grad()
-            tags = torch.tensor([x[0] for x in tags])
-            loss = criterion(predicted_tags.unsqueeze(0),tags.T)
-
-            loss.backward()
-            optimizer.step()
-            model.zero_grad()
\ No newline at end of file
+            predicted_tags = model.decode(batch_tokens.to(device))
+            print('dupa')
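
Note on the import and device setup in this patch: "from torch._C import device" looks like an accidental IDE auto-import and is never used; torch.device is the public API. model.cuda(0) is also redundant once model.to(device) has run. A minimal sketch of the conventional setup, assuming a CPU fallback is acceptable (the patch itself hard-codes "cuda"):

    import torch

    # torch.device is the public API; the private torch._C import is not needed.
    # The CPU fallback is an assumption; the patch hard-codes "cuda".
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)  # calling model.cuda(0) after this is redundant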
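
Note on the training step: the self.crf.decode(out) call suggests self.crf is a torchcrf.CRF layer, and torchcrf's CRF.forward() returns the log-likelihood of the tag sequence. Calling predicted_tags.backward() on that value and stepping Adam therefore minimizes the log-likelihood, the opposite of the intended objective; the usual loss is the negated log-likelihood. A minimal sketch of the inner training step under that torchcrf assumption (the now-unused CrossEntropyLoss criterion could then be dropped):

    # Assumes self.crf is torchcrf.CRF, whose forward() returns the sequence
    # log-likelihood (a scalar tensor with the default reduction="sum").
    log_likelihood = model(batch_tokens.to(device), tags.to(device))
    loss = -log_likelihood  # negative log-likelihood is what we minimize

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()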
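
Note on the eval branch: under the same torchcrf assumption, CRF.decode() returns a plain list of tag-index lists, one per batch element, so the print('dupa') placeholder could become a readable printout. A sketch, where idx2label is a hypothetical reverse mapping of labels_vocab (not present in the patch):

    # idx2label is hypothetical here, e.g. built once as:
    #   idx2label = {v: k for k, v in labels_vocab.items()}
    decoded = model.decode(batch_tokens.to(device))  # list of tag-index lists
    print([idx2label[t] for t in decoded[0]])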