# en-ner-conll-2003/run.py
# (Source-viewer header residue: 270 lines, 7.8 KiB, Python; "Raw | Normal View | History".)
# blame: 2022-05-29 00:10:21 +02:00
import numpy as np
import gensim
import torch
import pandas as pd
import seaborn as sns
from sklearn.model_selection import train_test_split
import time
# blame: 2022-05-30 22:43:33 +02:00
import copy
# blame: 2022-05-29 00:10:21 +02:00
# from datasets import load_dataset
from torchtext.vocab import Vocab, vocab
from collections import Counter, OrderedDict
# blame: 2022-05-30 22:43:33 +02:00
import string
# blame: 2022-05-29 00:10:21 +02:00
# from sklearn.datasets import fetch_20newsgroups
# https://scikit-learn.org/0.19/datasets/twenty_newsgroups.html
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics import accuracy_score
import lzma
import torchtext.vocab
from sklearn.multiclass import OneVsRestClassifier
from sklearn.linear_model import LogisticRegression
from datetime import datetime
# NOTE(review): FEAUTERES (sic) is not referenced anywhere in the visible code.
FEAUTERES = 10_000

# IOB tag -> class id. The list order fixes the ids the model is trained on,
# so it must not be reordered.
Y_names = {
    tag: index
    for index, tag in enumerate(
        ['O', 'B-PER', 'B-LOC', 'I-PER', 'B-MISC', 'I-MISC', 'I-LOC', 'B-ORG', 'I-ORG']
    )
}

# Class id -> IOB tag; derived from Y_names so the two maps cannot drift apart.
Y_names_re = {index: tag for tag, index in Y_names.items()}

# Number of output classes of the tagger.
OUTPUT_SIZE = len(Y_names)
def build_vocab(dataset):
    """Build a torchtext vocabulary over every token in ``dataset``.

    Args:
        dataset: iterable of documents, each an iterable of token strings.

    Returns:
        A ``torchtext.vocab.Vocab`` that also contains the specials
        ``<unk>``, ``<pad>``, ``<bos>`` and ``<eos>``. Its default index is 0,
        so out-of-vocabulary tokens map to index 0 (``<unk>`` when torchtext
        places the specials first, which is its default).
    """
    # Counter keeps first-occurrence insertion order, which torchtext uses
    # to assign ids to the non-special tokens.
    token_counts = Counter(token for document in dataset for token in document)
    result = torchtext.vocab.vocab(token_counts, specials=['<unk>', '<pad>', '<bos>', '<eos>'])
    result.set_default_index(0)
    return result
def data_process(dt, vocabulary=None):
    """Turn tokenised documents into id tensors framed by ``<bos>``/``<eos>``.

    Args:
        dt: iterable of documents, each a sequence of token strings.
        vocabulary: mapping supporting ``vocabulary[token] -> int`` (e.g. a
            torchtext ``Vocab``). Defaults to the module-level ``vocab``, which
            the training script rebinds to the built vocabulary. Before that
            rebinding the name is the torchtext ``vocab`` factory and indexing
            it fails, so pass the vocabulary explicitly when calling this
            outside the script.

    Returns:
        A list with one 1-D ``torch.long`` tensor per document:
        ``[<bos>, id(token_1), ..., id(token_n), <eos>]``.
    """
    if vocabulary is None:
        vocabulary = vocab
    bos_id = vocabulary['<bos>']
    eos_id = vocabulary['<eos>']
    return [
        torch.tensor([bos_id] + [vocabulary[token] for token in document] + [eos_id], dtype=torch.long)
        for document in dt
    ]
def labels_process(data):
    """Convert label sequences into class-id tensors aligned with ``data_process``.

    Args:
        data: iterable of label sequences; every label must be a key of
            ``Y_names`` (a KeyError is raised otherwise).

    Returns:
        A list with one 1-D ``torch.long`` tensor per sequence. A 0 (the id
        of 'O') is added at both ends so the labels line up with the
        ``<bos>``/``<eos>`` ids that ``data_process`` adds to each sentence.
    """
    tensors = []
    for row in data:
        ids = [0]
        ids.extend(Y_names[label] for label in row)
        ids.append(0)
        tensors.append(torch.tensor(ids, dtype=torch.long))
    return tensors
class NERModel(torch.nn.Module):
    """Feed-forward token classifier over a fixed-length vector of ids.

    Every input element is embedded. The embeddings are flattened into one
    vector, and a single linear layer maps it to one unnormalised score per
    label.

    In this script each input is 3 vocabulary ids followed by the 7 0/1 flags
    from ``manual_process`` (10 elements). The flags are looked up in the same
    embedding table, so flag values 0/1 share the vectors of vocabulary ids 0/1.

    Args:
        num_embeddings: size of the embedding table; must be at least
            ``len(vocab)``. The default 23627 is the value that was
            hard-coded here, presumably ``len(vocab)`` for the training data
            (TODO confirm).
        embedding_dim: size of each embedding vector.
        input_len: number of ids per input (default 10 = 3 window ids + 7 flags).
        num_classes: number of output labels (default 9, the number of tags).
    """

    def __init__(self, num_embeddings=23627, embedding_dim=200, input_len=10, num_classes=9):
        super().__init__()
        self.emb = torch.nn.Embedding(num_embeddings, embedding_dim)
        self.fc1 = torch.nn.Linear(input_len * embedding_dim, num_classes)
        # No softmax layer: torch.nn.CrossEntropyLoss, used as the training
        # criterion, expects raw logits.

    def forward(self, x):
        """Return the logits for a 1-D tensor of ``input_len`` ids."""
        x = self.emb(x)
        # Flatten to a single vector; raises if x does not hold exactly input_len ids.
        x = x.reshape(self.fc1.in_features)
        return self.fc1(x)
def _predict_line(token_ids, raw_tokens):
    """Return one predicted tag per real token of a sentence (``<bos>``/``<eos>`` excluded)."""
    labels = []
    for j in range(1, len(token_ids) - 1):
        # Same windowing as the training loop, so features match what the model saw.
        features = manual_process(token_ids[j - 1: j + 2], raw_tokens[j - 1: j + 2])
        try:
            scores = ner_model(features)
        except (RuntimeError, IndexError) as err:
            # Previously a bare `except: pass` silently dropped the label and
            # shifted every later tag on the line; keep the line aligned instead.
            print(f'prediction failed at token {j}: {err}')
            labels.append('O')
            continue
        labels.append(Y_names_re[int(torch.argmax(scores))])
    return labels


def _repair_iob(labels):
    """Rewrite stray ``I-`` tags so the sequence is consistent.

    An ``I-`` tag at the start of the line or right after 'O' becomes ``B-``.
    Any other ``I-`` tag takes the entity type of the previous tag.
    """
    repaired = []
    last_label = None
    for label in labels:
        if label.startswith('I-'):
            if last_label is None or last_label == 'O':
                label = 'B-' + label[2:]
            else:
                label = 'I-' + last_label[2:]
        repaired.append(label)
        last_label = label
    return repaired


def generate_out(folder_path):
    """Tag ``{folder_path}/in.tsv`` and write one tag line per input line to ``{folder_path}/out.tsv``.

    Relies on the module-level ``ner_model`` and ``vocab`` that the training
    script assigns before calling this.
    """
    print('Generating out')
    with open(f"{folder_path}/in.tsv", 'r', encoding='utf-8') as file:
        X_dev = [line.strip().split(' ') for line in file]
    print("step 5")
    test_tokens_ids = data_process(X_dev)
    # Inference only: skip building autograd graphs for every token.
    with torch.no_grad():
        predicted = [_predict_line(ids, raw) for ids, raw in zip(test_tokens_ids, X_dev)]
    print("step 6")
    with open(f"{folder_path}/out.tsv", "w", encoding='utf-8') as f:
        for labels in predicted:
            f.write(" ".join(_repair_iob(labels)) + "\n")
# def predict(data):
# ner_model.eval()
# predictions = []
# for i in range(len(data)):
# predictions.append(ner_model(X))
# return predictions
# blame: 2022-05-30 22:43:33 +02:00
def manual_process(tens, tokens):
    """Append seven orthographic 0/1 flags for the centre word to a window of ids.

    Args:
        tens: 1-D tensor of token ids for the context window. The callers pass
            ``ids[j-1:j+2]``, i.e. 3 ids centred on position ``j``.
        tokens: the raw-token slice the callers take with the same indices,
            ``raw[j-1:j+2]``. The id sequence starts with ``<bos>`` (added by
            ``data_process``) but the raw sentence does not. So the raw word
            aligned with the centre id is ``tokens[0]``, not ``tokens[1]``.
            Reading ``tokens[1]`` described the next word instead, and gave the
            last word of each sentence all-zero flags.

    Returns:
        ``tens`` followed by 7 flags, in this order:

        - first character is uppercase;
        - word is alphanumeric;
        - word contains punctuation;
        - word is numeric;
        - word is all uppercase;
        - word contains '-';
        - word contains '/'.

        All flags are 0 when there is no word (empty slice or empty string).
    """
    flags = [0] * 7
    word = tokens[0] if tokens else ''
    if word:
        flags[0] = int(word[0].isupper())
        flags[1] = int(word.isalnum())
        flags[2] = int(any(ch in string.punctuation for ch in word))
        flags[3] = int(word.isnumeric())
        flags[4] = int(word.isupper())
        flags[5] = int('-' in word)
        flags[6] = int('/' in word)
    return torch.cat((tens, torch.tensor(flags)), 0)
# blame: 2022-05-29 00:10:21 +02:00
def _read_training_data(path):
    """Read parallel token/label lists from the xz-compressed training TSV.

    Each line holds space-separated labels in column 0 and space-separated
    tokens in column 1. Lines with fewer than two columns, or whose token and
    label counts differ, are skipped.

    Returns:
        ``(sentences, labels)``: parallel lists of token lists and label lists.
    """
    sentences = []
    labels = []
    with lzma.open(path, 'r') as file:
        for raw_line in file:
            columns = raw_line.strip().decode("utf-8").split('\t')
            if len(columns) < 2:
                # Blank or malformed line; previously this raised IndexError.
                continue
            label_seq = columns[0].split()
            token_seq = columns[1].split()
            if len(token_seq) == len(label_seq):
                sentences.append(token_seq)
                labels.append(label_seq)
    return sentences, labels


def _ratio(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is zero."""
    return numerator / denominator if denominator else 0.0


if __name__ == "__main__":
    start_time = time.time()
    X, Y = _read_training_data('train/train.tsv.xz')
    # Raw token strings, needed for the hand-crafted features in manual_process.
    X_raw = X
    vocab = build_vocab(X)
    train_tokens_ids = data_process(X)
    train_labels = labels_process(Y)
    ner_model = NERModel()
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(ner_model.parameters())
    for epoch in range(4):
        print('started epoch', epoch)
        start_time_epoch = time.time()
        loss_score = 0
        acc_score = 0
        prec_score = 0
        selected_items = 0
        recall_score = 0
        relevant_items = 0
        items_total = 0
        ner_model.train()
        # The last 50 sentences are skipped; nothing in this script evaluates on them.
        for i in range(len(train_labels) - 50):
            for j in range(1, len(train_labels[i]) - 1):
                features = manual_process(train_tokens_ids[i][j - 1: j + 2], X_raw[i][j - 1: j + 2])
                target = train_labels[i][j: j + 1]
                Y_predictions = ner_model(features)
                # Metrics use the prediction made before this step's update.
                predicted = int(torch.argmax(Y_predictions))
                gold = target.item()
                acc_score += int(predicted == gold)
                if predicted != 0:
                    selected_items += 1
                    if predicted == gold:
                        prec_score += 1
                if gold != 0:
                    relevant_items += 1
                    if predicted == gold:
                        recall_score += 1
                items_total += 1
                optimizer.zero_grad()
                loss = criterion(Y_predictions.unsqueeze(0), target)
                loss.backward()
                optimizer.step()
                loss_score += loss.item()
        # Guarded divisions: an epoch with no predicted or no gold entities used
        # to crash with ZeroDivisionError.
        precision = _ratio(prec_score, selected_items)
        recall = _ratio(recall_score, relevant_items)
        if precision and recall:
            f1_score = (2 * precision * recall) / (precision + recall)
        else:
            f1_score = 0
        print('epoch: ', epoch)
        print('loss: ', _ratio(loss_score, items_total))
        print('acc: ', _ratio(acc_score, items_total))
        print('prec: ', precision)
        print('recall: : ', recall)
        print('f1: ', f1_score)
        print("--- %s seconds ---" % (time.time() - start_time_epoch))
    print("--- %s seconds ---" % (time.time() - start_time))
    print("Hello, World!")
    generate_out('dev-0')
    generate_out('test-A')