forked from kubapok/en-ner-conll-2003
working on it
This commit is contained in:
parent 1397a7a5c2
commit a2f39d1f56
104
main.py
Normal file
@@ -0,0 +1,104 @@
from collections import Counter, OrderedDict

import pandas as pd
import spacy
import torch
from torch.utils.data import DataLoader
from torchcrf import CRF
from torchtext.vocab import vocab as make_vocab
from tqdm import tqdm


# spaCy pipeline; only needed if the commented-out lemma-based tokenization in
# process_document() is re-enabled.
nlp = spacy.load('en_core_web_sm')


class Model(torch.nn.Module):
    def __init__(self, num_tags, vocab_size):
        super().__init__()
        self.emb = torch.nn.Embedding(vocab_size, 100)
        self.relu = torch.nn.ReLU()
        # Single-layer GRU over the embedded tokens; batch_first matches the
        # (batch, seq_len) layout used in the training loop below.
        self.gru = torch.nn.GRU(100, 256, 1, batch_first=True)
        # Project GRU hidden states to per-token tag emissions.
        self.hidden2tag = torch.nn.Linear(256, num_tags)
        self.crf = CRF(num_tags, batch_first=True)

    def forward(self, data, tags):
        emb = self.relu(self.emb(data))
        out, h_n = self.gru(emb)
        emissions = self.hidden2tag(out)
        # The CRF returns the log-likelihood of the gold tag sequence; negate it
        # so the forward pass yields a loss that can be minimized directly.
        return -self.crf(emissions, tags)
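
    # Sketch (not part of this commit): prediction could reuse the same layers
    # and let the CRF run Viterbi decoding over the emissions, roughly:
    #
    #     def predict(self, data):
    #         emb = self.relu(self.emb(data))
    #         out, _ = self.gru(emb)
    #         return self.crf.decode(self.hidden2tag(out))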


def process_document(document):
    # Lemma-based tokenization with spaCy (currently disabled):
    # return [str(tok.lemma_) for tok in nlp(document)]
    return document.split(" ")


def build_vocab(dataset):
    counter = Counter()
    for document in dataset:
        counter.update(process_document(document))
    # Build the vocabulary from tokens ordered by frequency.
    sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True)
    ordered_dict = OrderedDict(sorted_by_freq_tuples)
    v = make_vocab(ordered_dict)
    # Send out-of-vocabulary tokens to a dedicated <unk> entry so lookups always
    # return a valid embedding index.
    v.insert_token("<unk>", 0)
    v.set_default_index(v["<unk>"])
    return v


def data_process(dt):
    # Encode each document as a tensor of vocabulary indices.
    return [torch.tensor([vocab[token] for token in document.split(" ")], dtype=torch.long) for document in dt]


def labels_process(dt):
    # Encode each label sequence as a tensor of tag indices.
    return [torch.tensor([labels_vocab[token] for token in document.split(" ")], dtype=torch.long) for document in dt]
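
# For illustration (assumed example, not from the original file):
# data_process(["John lives in Paris"]) would return a one-element list holding
# a LongTensor with the four vocabulary indices of those tokens.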


# train/train.tsv: one document per row with a tab-separated label sequence and
# text column (assuming the file has no header row, since the columns are named
# explicitly below).
data = pd.read_csv("train/train.tsv", sep="\t", header=None)
data.columns = ["labels", "text"]

vocab = build_vocab(data['text'])
# labels_vocab = build_vocab(data['labels'])

# BIO tags over the four CoNLL-2003 entity types (PER, LOC, ORG, MISC).
labels_vocab = {
    'O': 0,
    'B-PER': 1,
    'B-LOC': 2,
    'I-PER': 3,
    'B-MISC': 4,
    'I-MISC': 5,
    'I-LOC': 6,
    'B-ORG': 7,
    'I-ORG': 8
}

train_tokens_ids = data_process(data["text"])
train_labels = labels_process(data["labels"])

num_tags = 9
NUM_EPOCHS = 5
seq_length = 15

model = Model(num_tags, len(vocab.get_itos()))
# The CRF layer supplies the training loss (negative log-likelihood), so no
# separate criterion is needed.
optimizer = torch.optim.Adam(model.parameters())

# Not used by the loop below yet; kept for a later batched training setup.
train_dataloader = DataLoader(list(zip(train_tokens_ids, train_labels)), batch_size=64, shuffle=True)
# test_dataloader = DataLoader(train_labels, batch_size=64, shuffle=True)


for epoch in range(NUM_EPOCHS):
    model.train()
    # Iterate over documents and slide a non-overlapping window of seq_length
    # tokens over each one.
    for doc_idx in tqdm(range(len(train_labels))):
        for k in range(0, len(train_tokens_ids[doc_idx]) - seq_length, seq_length):
            # Shapes (1, seq_length): one windowed sequence per optimizer step.
            batch_tokens = train_tokens_ids[doc_idx][k: k + seq_length].unsqueeze(0)
            tags = train_labels[doc_idx][k: k + seq_length].unsqueeze(0)

            optimizer.zero_grad()
            # The model returns the negative CRF log-likelihood of the gold tags.
            loss = model(batch_tokens, tags)
            loss.backward()
            optimizer.step()