transformer/Transformer.ipynb

import pandas as pd
from tqdm.notebook import tqdm

from transformers import pipeline, AutoModelForTokenClassification, AutoTokenizer
# train.tsv: column 0 holds the space-separated NER tags, column 1 the raw text
dataset = pd.read_csv('train.tsv', sep='\t', header=None)
X_test = pd.read_csv('../dev-0/in.tsv', sep='\t', header=None)
Y_test = pd.read_csv('../dev-0/expected.tsv', sep='\t', header=None)

X_train = dataset[dataset.columns[1]]
Y_train = dataset[dataset.columns[0]]
X_test = X_test[X_test.columns[0]]
Y_test = Y_test[Y_test.columns[0]]

# Pre-tokenize on whitespace so predictions can later be aligned token by token
X_train = [text.split() for text in X_train]
X_test = [text.split() for text in X_test]
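A quick look at the first dev example confirms the format; the tokens shown in the comments are illustrative, not actual dataset contents:

# Illustrative sanity check of the loaded dev data
print(X_test[0][:5])          # first five tokens of the first sentence
print(Y_test[0].split()[:5])  # their gold NER tags, e.g. ['O', 'B-LOC', ...]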
# NER model fine-tuned on CoNLL-2003; the cased BERT checkpoints share one
# WordPiece vocabulary, so bert-base-cased can stand in as the tokenizer here
model = AutoModelForTokenClassification.from_pretrained("dbmdz/bert-large-cased-finetuned-conll03-english")
tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-cased")

recognizer = pipeline("ner", model=model, tokenizer=tokenizer)
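For orientation, the non-aggregated pipeline returns one dict per word piece. The call below is a hypothetical example; the exact tags, scores, and offsets are illustrative:

# Hypothetical example call; each raw hit is a per-piece dict
sample = recognizer("George Washington went to Washington")
print(sample[0])
# e.g. {'entity': 'I-PER', 'score': 0.999, 'index': 1,
#       'word': 'George', 'start': 0, 'end': 6}
# Sub-word continuations come back as separate hits whose 'word' starts with "##"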
full_tags = []
for idx, el in tqdm(enumerate(X_test)):
    out_tags = []
    tags_corrected = []

    temp = []           # merged surface forms (collected for debugging, unused downstream)
    tags_joined = []    # entity tag of each merged word
    tag = 'NULL'
    t = ""
    # Stitch word-piece hits back into whole words: a piece starting with "##"
    # continues the current word; anything else closes it and starts a new one
    for r in recognizer(" ".join(el)):
        if len(t) == 0:
            t = r['word']
            tag = r['entity']
            continue
        if r['word'].startswith("##"):
            t += r['word'][2:]
            continue
        temp.append(t)
        tags_joined.append(tag)
        t = r['word']
        tag = r['entity']
    # Flush the final word, which the original branching silently dropped
    if len(t) != 0:
        temp.append(t)
        tags_joined.append(tag)

    # Align merged predictions to the gold token count: "O" positions stay "O",
    # entity positions consume the next predicted tag (or fall back to "O")
    for tag in Y_test[idx].split():
        if tag == "O":
            out_tags.append("O")
        elif len(tags_joined) > 0:
            out_tags.append(tags_joined.pop(0))
        else:
            out_tags.append("O")

    #print(len(Y_test[idx].split()), len(out_tags))

    out_tags = " ".join(out_tags).replace("I-","B-").split()
        
    last_tag = out_tags[0]
    tags_corrected.append(last_tag)
    
    # Re-derive BIO: a tag identical to its predecessor becomes a continuation ("I-")
    for tag in out_tags[1:]:
        if tag == last_tag:
            tags_corrected.append(tag.replace("B-", "I-"))
        else:
            tags_corrected.append(tag)
        last_tag = tag
            
    #print(len(Y_test[idx].split()), len(tags_corrected))

    full_tags.append(tags_corrected)

print(len(full_tags))

# for idx, el in tqdm(enumerate(full_tags)):
#     if len(el) != len(Y_test[idx].split()):
#         print("Something's wrong, sir")
   
Some weights of the model checkpoint at dbmdz/bert-large-cased-finetuned-conll03-english were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
215
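The manual word-piece merging above could likely be replaced by the pipeline's built-in grouping; a minimal sketch, assuming the same model and tokenizer objects:

# Sketch: let the pipeline merge sub-word pieces itself instead of doing it by hand.
# aggregation_strategy="simple" returns one hit per entity span, with an
# 'entity_group' field ("PER", "LOC", ...) instead of the per-piece 'entity'
recognizer_grouped = pipeline(
    "ner",
    model=model,
    tokenizer=tokenizer,
    aggregation_strategy="simple",
)
# Each hit then looks like {'entity_group': 'PER', 'word': 'George Washington', ...},
# so the B-/I- prefixes would still have to be re-derived when writing the output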
with open("out.tsv", 'w') as file:
    for el in full_tags:
    
        file.write(" ".join(el)) 
        file.write(f"\n")