working on it

This commit is contained in:
wangobango 2021-06-20 19:04:16 +02:00
parent 1397a7a5c2
commit a2f39d1f56

104
main.py Normal file
View File

@ -0,0 +1,104 @@
from os import sep
from nltk import word_tokenize
import pandas as pd
import torch
from tqdm import tqdm
from torchtext.vocab import vocab
from collections import Counter, OrderedDict
import spacy
from torchcrf import CRF
from torch.utils.data import DataLoader
nlp = spacy.load('en_core_web_sm')
class Model(torch.nn.Module):
def __init__(self, num_tags, seq_length):
super(Model, self).__init__()
self.emb = torch.nn.Embedding(len(vocab.get_itos()), 100)
self.gru = torch.nn.GRU(100, 256, 1, batch_first=True)
self.hidden2tag = torch.nn.Linear(256, 9)
self.crf = CRF(num_tags, batch_first=True)
self.relu = torch.nn.ReLU()
self.fc1 = torch.nn.Linear(1, seq_length)
self.softmax = torch.nn.Softmax(dim=0)
self.sigm = torch.nn.Sigmoid()
def forward(self, data, tags):
emb = self.relu(self.emb(data))
out, h_n = self.gru(emb)
# out = self.dense1(out.squeeze(0).T)
out = self.hidden2tag(out)
out = self.crf(out, tags.T)
out = self.sigm(self.fc1(torch.tensor([out])))
return out
def process_document(document):
# return [str(tok.lemma) for tok in nlp(document)]
return document.split(" ")
def build_vocab(dataset):
counter = Counter()
for document in dataset:
counter.update(process_document(document))
sorted_by_freq_tuples = sorted(counter.items(), key=lambda x: x[1], reverse=True)
ordered_dict = OrderedDict(sorted_by_freq_tuples)
v = vocab(counter)
default_index = -1
v.set_default_index(default_index)
return v
def data_process(dt):
return [ torch.tensor([vocab[token] for token in document.split(" ") ], dtype = torch.long) for document in dt]
def labels_process(dt):
return [ torch.tensor([labels_vocab[token] for token in document.split(" ") ], dtype = torch.long) for document in dt]
data = pd.read_csv("train/train.tsv", sep="\t")
data.columns = ["labels", "text"]
vocab = build_vocab(data['text'])
# labels_vocab = build_vocab(data['labels'])
labels_vocab = {
'O': 0,
'B-PER': 1,
'B-LOC': 2,
'I-PER': 3,
'B-MISC': 4,
'I-MISC': 5,
'I-LOC': 6,
'B-ORG': 7,
'I-ORG': 8
}
train_tokens_ids = data_process(data["text"])
train_labels = labels_process(data["labels"])
num_tags = 9
NUM_EPOCHS = 5
seq_length = 15
model = Model(num_tags, seq_length)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())
train_dataloader = DataLoader(list(zip(train_tokens_ids, train_labels)), batch_size=64, shuffle=True)
# test_dataloader = DataLoader(train_labels, batch_size=64, shuffle=True)
for i in range(NUM_EPOCHS):
model.train()
#for i in tqdm(range(500)):
for i in tqdm(range(len(train_labels))):
for k in range(0, len(train_tokens_ids[i]) - seq_length, seq_length):
batch_tokens = train_tokens_ids[i][k: k + seq_length].unsqueeze(0)
tags = train_labels[i][k: k + seq_length].unsqueeze(1)
predicted_tags = model(batch_tokens, tags)
optimizer.zero_grad()
tags = torch.tensor([x[0] for x in tags])
loss = criterion(predicted_tags.unsqueeze(0),tags.T)
loss.backward()
optimizer.step()
model.zero_grad()