From 50621d5a7fe14984e1276db3d629f740eaef5668 Mon Sep 17 00:00:00 2001 From: wangobango Date: Mon, 21 Jun 2021 00:43:43 +0200 Subject: [PATCH] its working --- main.py | 49 ++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/main.py b/main.py index 547e290..cf317ca 100644 --- a/main.py +++ b/main.py @@ -9,6 +9,9 @@ from collections import Counter, OrderedDict import spacy from torchcrf import CRF from torch.utils.data import DataLoader +import numpy as np +from sklearn.metrics import accuracy_score, f1_score, classification_report + nlp = spacy.load('en_core_web_sm') @@ -31,7 +34,7 @@ class Model(torch.nn.Module): out = self.hidden2tag(out) out = self.crf(out, tags.T) # out = self.sigm(self.fc1(torch.tensor([out]))) - return out + return -out def decode(self, data): emb = self.relu(self.emb(data)) @@ -41,6 +44,12 @@ class Model(torch.nn.Module): out = self.crf.decode(out) return out + def train_mode(self): + self.crf.train() + + def eval_mode(self): + self.crf.eval() + def process_document(document): # return [str(tok.lemma) for tok in nlp(document)] @@ -62,7 +71,7 @@ def data_process(dt): def labels_process(dt): return [ torch.tensor([labels_vocab[token] for token in document.split(" ") ], dtype = torch.long) for document in dt] - +save_path = "train/out.tsv" data = pd.read_csv("train/train.tsv", sep="\t") data.columns = ["labels", "text"] @@ -81,12 +90,14 @@ labels_vocab = { 'I-ORG': 8 } +inv_labels_vocab = {v: k for k, v in labels_vocab.items()} + train_tokens_ids = data_process(data["text"]) train_labels = labels_process(data["labels"]) num_tags = 9 NUM_EPOCHS = 5 -seq_length = 15 +seq_length = 5 model = Model(num_tags, seq_length) device = torch.device("cuda") @@ -99,13 +110,14 @@ optimizer = torch.optim.Adam(model.parameters()) train_dataloader = DataLoader(list(zip(train_tokens_ids, train_labels)), batch_size=64, shuffle=True) # test_dataloader = DataLoader(train_labels, batch_size=64, shuffle=True) -mode = "train" -# mode = "eval" +# mode = "train" +mode = "eval" # mode = "generate" if mode == "train": for i in range(NUM_EPOCHS): model.train() + model.train_mode() #for i in tqdm(range(500)): for i in tqdm(range(len(train_labels))): for k in range(0, len(train_tokens_ids[i]) - seq_length, seq_length): @@ -114,22 +126,41 @@ if mode == "train": predicted_tags = model(batch_tokens.to(device), tags.to(device)) - optimizer.zero_grad() # tags = torch.tensor([x[0] for x in tags]) # loss = criterion(predicted_tags.unsqueeze(0),tags.T) predicted_tags.backward() optimizer.step() model.zero_grad() + model.crf.zero_grad() + optimizer.zero_grad() torch.save(model.state_dict(), "model.torch") if mode == "eval": model.eval() - for i in tqdm(range(len(train_labels))): + model.eval_mode() + predicted = [] + correct = [] + model.load_state_dict(torch.load("model.torch")) + for i in tqdm(range(0, len(train_labels))): for k in range(0, len(train_tokens_ids[i]) - seq_length, seq_length): batch_tokens = train_tokens_ids[i][k: k + seq_length].unsqueeze(0) tags = train_labels[i][k: k + seq_length].unsqueeze(1) - predicted_tags = model.decode(batch_tokens.to(device)) - print('dupa') + predicted += predicted_tags[0] + correct += [x[0] for x in tags.numpy().tolist()] + print(classification_report(correct, predicted)) + print(accuracy_score(correct, predicted)) + print(f1_score(correct, predicted, average="weighted")) + + predicted = list(map(lambda x: inv_labels_vocab[x], predicted)) + slices = [len(x.split(" ")) for x in data["text"]] + with open(save_path, "a") as save: + accumulator = 0 + for slice in slices: + save.write(' '.join(predicted[accumulator: accumulator + slice])) + accumulator += slice + + + \ No newline at end of file