en-ner-conll-2003/main.py

import pandas as pd
import torch
from tqdm import tqdm
# Alias the torchtext factory so it is not shadowed by the `vocab` built below.
from torchtext.vocab import vocab as vocab_factory
from collections import Counter, OrderedDict
import spacy
from torchcrf import CRF
from torch.utils.data import DataLoader

# Only needed by the commented-out spaCy lemmatizer in process_document().
nlp = spacy.load('en_core_web_sm')
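
# Sequence tagger: token embeddings feed a unidirectional GRU, a linear layer
# projects the hidden states to per-tag emission scores, and a CRF models the
# tag-transition structure on top of them.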
class Model(torch.nn.Module):
    def __init__(self, num_tags, seq_length, vocab_size):
        super(Model, self).__init__()
        self.emb = torch.nn.Embedding(vocab_size, 100)
        self.gru = torch.nn.GRU(100, 256, 1, batch_first=True)
        self.hidden2tag = torch.nn.Linear(256, num_tags)
        self.crf = CRF(num_tags, batch_first=True)
        self.relu = torch.nn.ReLU()
        # Leftovers from earlier experiments; not used in forward() or decode().
        self.fc1 = torch.nn.Linear(1, seq_length)
        self.softmax = torch.nn.Softmax(dim=0)
        self.sigm = torch.nn.Sigmoid()

    def forward(self, data, tags):
        # Returns the CRF log-likelihood of `tags` given the GRU emissions.
        emb = self.relu(self.emb(data))
        out, h_n = self.gru(emb)
        out = self.hidden2tag(out)
        out = self.crf(out, tags.T)
        return out

    def decode(self, data):
        # Viterbi-decodes the most likely tag sequence for `data`.
        emb = self.relu(self.emb(data))
        out, h_n = self.gru(emb)
        out = self.hidden2tag(out)
        out = self.crf.decode(out)
        return out
def process_document(document):
    # Alternative: lemmatize with spaCy instead of splitting on whitespace.
    # return [str(tok.lemma_) for tok in nlp(document)]
    return document.split(" ")
def build_vocab(dataset):
    counter = Counter()
    for document in dataset:
        counter.update(process_document(document))
    # Order tokens by frequency before handing them to torchtext.
    sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True)
    ordered_dict = OrderedDict(sorted_by_freq_tuples)
    v = vocab_factory(ordered_dict)
    # NOTE: -1 is not a valid embedding index, so truly out-of-vocabulary
    # tokens would need a dedicated <unk> entry before inference.
    default_index = -1
    v.set_default_index(default_index)
    return v
def data_process(dt):
    return [torch.tensor([vocab[token] for token in document.split(" ")], dtype=torch.long) for document in dt]

def labels_process(dt):
    return [torch.tensor([labels_vocab[token] for token in document.split(" ")], dtype=torch.long) for document in dt]
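
# train/train.tsv is read as a two-column TSV: one tagged document per row,
# with space-separated BIO labels in the first column and the matching
# space-separated tokens in the second.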
data = pd.read_csv("train/train.tsv", sep="\t")
data.columns = ["labels", "text"]
vocab = build_vocab(data['text'])
# labels_vocab = build_vocab(data['labels'])
labels_vocab = {
    'O': 0,
    'B-PER': 1,
    'B-LOC': 2,
    'I-PER': 3,
    'B-MISC': 4,
    'I-MISC': 5,
    'I-LOC': 6,
    'B-ORG': 7,
    'I-ORG': 8,
}
train_tokens_ids = data_process(data["text"])
train_labels = labels_process(data["labels"])
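
# Each document becomes one variable-length tensor of token ids and one of
# label ids; the loops below slice these into fixed windows of seq_length
# tokens instead of batching whole documents.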
num_tags = len(labels_vocab)
NUM_EPOCHS = 5
seq_length = 15
model = Model(num_tags, seq_length, len(vocab.get_itos()))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# The CRF layer supplies the training loss, so no separate criterion is needed.
optimizer = torch.optim.Adam(model.parameters())
# NOTE: this DataLoader is never consumed below; the loops iterate over
# documents directly because they have varying lengths.
train_dataloader = DataLoader(list(zip(train_tokens_ids, train_labels)), batch_size=64, shuffle=True)
# test_dataloader = DataLoader(train_labels, batch_size=64, shuffle=True)
mode = "train"
# mode = "eval"
# mode = "generate"

if mode == "train":
    for epoch in range(NUM_EPOCHS):
        model.train()
        for i in tqdm(range(len(train_labels))):
            # Slide a fixed-size window over each document.
            for k in range(0, len(train_tokens_ids[i]) - seq_length, seq_length):
                batch_tokens = train_tokens_ids[i][k: k + seq_length].unsqueeze(0)
                tags = train_labels[i][k: k + seq_length].unsqueeze(1)
                log_likelihood = model(batch_tokens.to(device), tags.to(device))
                optimizer.zero_grad()
                # CRF.forward returns the log-likelihood, so the loss to
                # minimize is its negative.
                loss = -log_likelihood
                loss.backward()
                optimizer.step()
    torch.save(model.state_dict(), "model.torch")
if mode == "eval":
model.eval()
2021-06-20 19:04:16 +02:00
for i in tqdm(range(len(train_labels))):
for k in range(0, len(train_tokens_ids[i]) - seq_length, seq_length):
batch_tokens = train_tokens_ids[i][k: k + seq_length].unsqueeze(0)
tags = train_labels[i][k: k + seq_length].unsqueeze(1)
2021-06-20 22:03:34 +02:00
predicted_tags = model.decode(batch_tokens.to(device))
print('dupa')
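
# A minimal sketch (not in the original): map the index sequences returned by
# CRF.decode back to their BIO tag strings, using the labels_vocab mapping
# defined above.
idx2label = {v: k for k, v in labels_vocab.items()}

def decode_to_labels(sequences):
    # `sequences` is the List[List[int]] that Model.decode returns.
    return [[idx2label[ix] for ix in seq] for seq in sequences]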