From a2f39d1f5670455669a78fe9c3a8f8b6126718db Mon Sep 17 00:00:00 2001
From: wangobango
Date: Sun, 20 Jun 2021 19:04:16 +0200
Subject: [PATCH] working on it

---
 main.py | 104 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 104 insertions(+)
 create mode 100644 main.py

diff --git a/main.py b/main.py
new file mode 100644
index 0000000..19bd6e2
--- /dev/null
+++ b/main.py
@@ -0,0 +1,104 @@
+import pandas as pd
+import torch
+from tqdm import tqdm
+from torchtext.vocab import vocab as make_vocab
+from collections import Counter, OrderedDict
+import spacy
+from torchcrf import CRF
+
+nlp = spacy.load('en_core_web_sm')
+
+
+class Model(torch.nn.Module):
+    def __init__(self, vocab_size, num_tags):
+        super(Model, self).__init__()
+        self.emb = torch.nn.Embedding(vocab_size, 100)
+        self.gru = torch.nn.GRU(100, 256, 1, batch_first=True)
+        self.hidden2tag = torch.nn.Linear(256, num_tags)
+        self.crf = CRF(num_tags, batch_first=True)
+        self.relu = torch.nn.ReLU()
+
+    def _emissions(self, data):
+        # Embed the token ids and run the GRU to get per-token features,
+        # then project them to per-tag emission scores for the CRF.
+        emb = self.relu(self.emb(data))
+        out, _ = self.gru(emb)
+        return self.hidden2tag(out)
+
+    def forward(self, data, tags):
+        # The CRF returns the log-likelihood of the gold tag sequence;
+        # negate it so it can be minimised as a training loss.
+        return -self.crf(self._emissions(data), tags)
+
+    def decode(self, data):
+        # Viterbi-decode the most likely tag sequence for each batch element.
+        return self.crf.decode(self._emissions(data))
+
+
+def process_document(document):
+    # return [str(tok.lemma_) for tok in nlp(document)]
+    return document.split(" ")
+
+
+def build_vocab(dataset):
+    counter = Counter()
+    for document in dataset:
+        counter.update(process_document(document))
+    sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True)
+    ordered_dict = OrderedDict(sorted_by_freq_tuples)
+    v = make_vocab(ordered_dict)
+    default_index = -1
+    v.set_default_index(default_index)
+    return v
+
+
+def data_process(dt):
+    return [torch.tensor([vocab[token] for token in document.split(" ")], dtype=torch.long) for document in dt]
+
+
+def labels_process(dt):
+    return [torch.tensor([labels_vocab[token] for token in document.split(" ")], dtype=torch.long) for document in dt]
+
+
+data = pd.read_csv("train/train.tsv", sep="\t")
+data.columns = ["labels", "text"]
+
+vocab = build_vocab(data["text"])
+# labels_vocab = build_vocab(data["labels"])
+
+labels_vocab = {
+    'O': 0,
+    'B-PER': 1,
+    'B-LOC': 2,
+    'I-PER': 3,
+    'B-MISC': 4,
+    'I-MISC': 5,
+    'I-LOC': 6,
+    'B-ORG': 7,
+    'I-ORG': 8,
+}
+
+train_tokens_ids = data_process(data["text"])
+train_labels = labels_process(data["labels"])
+num_tags = len(labels_vocab)
+NUM_EPOCHS = 5
+seq_length = 15
+
+model = Model(len(vocab), num_tags)
+optimizer = torch.optim.Adam(model.parameters())
+
+for epoch in range(NUM_EPOCHS):
+    model.train()
+    for i in tqdm(range(len(train_labels))):
+        # Slide a fixed-length window over each document so every training
+        # example fed to the model has the same sequence length.
+        for k in range(0, len(train_tokens_ids[i]) - seq_length, seq_length):
+            batch_tokens = train_tokens_ids[i][k: k + seq_length].unsqueeze(0)
+            batch_tags = train_labels[i][k: k + seq_length].unsqueeze(0)
+
+            optimizer.zero_grad()
+            loss = model(batch_tokens, batch_tags)
+            loss.backward()
+            optimizer.step()
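
Not part of the patch itself: a minimal sanity-check sketch for the script above, assuming the Model.decode helper defined in the diff. It Viterbi-decodes the first fixed-length window of the training data and maps the predicted indices back to tag strings by inverting labels_vocab; the id2label name is illustrative, not something the patch defines.

    # Invert labels_vocab so predicted indices map back to tag strings
    # (id2label is an illustrative name, not part of the patch).
    id2label = {idx: tag for tag, idx in labels_vocab.items()}

    model.eval()
    with torch.no_grad():
        # Decode the first seq_length-token window of the first training document.
        window = train_tokens_ids[0][:seq_length].unsqueeze(0)
        predicted = model.decode(window)[0]

    print([id2label[p] for p in predicted])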