7.3 KiB
7.3 KiB
import csv
import string
from collections import Counter

import pandas as pd
import torch
from datasets import load_dataset
from nltk.tokenize import word_tokenize
from torchtext.vocab import vocab
from tqdm.notebook import tqdm
from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
C:\Users\obses\AppData\Local\Programs\Python\Python310\lib\site-packages\torchtext\vocab\__init__.py:4: UserWarning: /!\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\ Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()` warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG) C:\Users\obses\AppData\Local\Programs\Python\Python310\lib\site-packages\torchtext\utils.py:4: UserWarning: /!\ IMPORTANT WARNING ABOUT TORCHTEXT STATUS /!\ Torchtext is deprecated and the last released version will be 0.18 (this one). You can silence this warning by calling the following at the beginnign of your scripts: `import torchtext; torchtext.disable_torchtext_deprecation_warning()` warnings.warn(torchtext._TORCHTEXT_DEPRECATION_MSG)
# Load the training split and the dev-0 split (tab-separated, no header row).
# QUOTE_NONE: the text contains literal '"' tokens; with pandas' default CSV
# quoting those would be parsed as quote characters, merging or shifting rows
# so that in.tsv and expected.tsv no longer line up.
dataset = pd.read_csv('train.tsv', sep='\t', header=None, quoting=csv.QUOTE_NONE)
X_test = pd.read_csv('../dev-0/in.tsv', sep='\t', header=None, quoting=csv.QUOTE_NONE)
Y_test = pd.read_csv('../dev-0/expected.tsv', sep='\t', header=None, quoting=csv.QUOTE_NONE)

# train.tsv: column 1 -> X_train (text), column 0 -> Y_train (presumably the
# space-separated tag sequence -- confirm against the data description).
X_train = dataset[dataset.columns[1]]
Y_train = dataset[dataset.columns[0]]
X_test = X_test[X_test.columns[0]]
Y_test = Y_test[Y_test.columns[0]]

# Whitespace-tokenise each line into a list of words.
X_train = [text.split() for text in X_train]
X_test = [text.split() for text in X_test]
# Load the CoNLL-03 fine-tuned BERT together with *its own* tokenizer, so the
# vocabulary and casing used for encoding always match the model's embeddings
# (previously the tokenizer came from a different checkpoint).
MODEL_NAME = "dbmdz/bert-large-cased-finetuned-conll03-english"
model = AutoModelForTokenClassification.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
recognizer = pipeline("ner", model=model, tokenizer=tokenizer)
# Predict NER tags for every dev-0 sentence.
#
# The recognizer runs on the space-joined sentence. Its per-sub-token
# predictions are mapped back onto the whitespace-separated input words through
# the character offsets ('start') that the pipeline reports. Only the model
# output and the input words are used -- the gold labels in Y_test are NOT
# consulted. This keeps the dev-0 score honest and lets the same code run on
# splits that have no expected.tsv.
# NOTE(review): lines longer than the model's maximum input length may be
# truncated or rejected depending on the transformers version -- confirm for
# long inputs.


def _word_index_by_char(words):
    """Map each character of " ".join(words) to the index of its word.

    The single separator spaces inserted by the join map to None.
    """
    index = []
    for i, word in enumerate(words):
        if i:
            index.append(None)  # the space " ".join puts before this word
        index.extend([i] * len(word))
    return index


def _align_entities(words, entities):
    """Return one raw tag per word from the recognizer's sub-token entities.

    A word takes the label of the first entity whose start offset falls inside
    it (i.e. its leading tagged word-piece wins). Words without any entity
    get "O".

    Raises:
        ValueError: if the pipeline did not report character offsets.
    """
    char_to_word = _word_index_by_char(words)
    tags = ["O"] * len(words)
    for ent in entities:
        start = ent.get("start")
        if start is None:
            raise ValueError(
                "recognizer did not report character offsets; "
                "use a fast tokenizer so entities can be aligned to words"
            )
        if 0 <= start < len(char_to_word):
            word_idx = char_to_word[start]
            if word_idx is not None and tags[word_idx] == "O":
                tags[word_idx] = ent["entity"]
    return tags


def _normalize_bio(tags):
    """Rewrite tags so each run of equal entity labels becomes B-X, I-X, ...

    Every "I-" prefix is first turned into "B-"; then any tag equal to its
    predecessor is switched to "I-". "O" is left unchanged.
    """
    fixed = []
    previous = None
    for tag in tags:
        tag = tag.replace("I-", "B-")
        fixed.append(tag.replace("B-", "I-") if tag == previous else tag)
        previous = tag
    return fixed


full_tags = []
for words in tqdm(X_test):
    if not words:
        # Empty input line: nothing to tag, keep the row so line counts match.
        full_tags.append([])
        continue
    entities = recognizer(" ".join(words))
    full_tags.append(_normalize_bio(_align_entities(words, entities)))
print(len(full_tags))
Some weights of the model checkpoint at dbmdz/bert-large-cased-finetuned-conll03-english were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight'] - This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). - This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
0it [00:00, ?it/s]
215
# Write one line of space-separated tags per input sentence.
# newline="\n" forces LF line endings: in text mode Python would otherwise
# translate "\n" to the platform separator (CRLF on Windows).
with open("out.tsv", "w", encoding="utf-8", newline="\n") as file:
    file.writelines(" ".join(tags) + "\n" for tags in full_tags)