Transformer/transformer.ipynb

import pandas as pd
from transformers import pipeline

# Load the data (the training split is read in but not used below; the
# pretrained pipeline needs no fine-tuning here)
train_data = pd.read_csv('en-ner-conll-2003/train/train.tsv.xz', delimiter='\t', header=None, compression='xz')
in_data_dev0 = pd.read_csv('en-ner-conll-2003/dev-0/in.tsv', delimiter='\t', header=None)
expected_data = pd.read_csv('en-ner-conll-2003/dev-0/expected.tsv', delimiter='\t', header=None)
in_data_testA = pd.read_csv('en-ner-conll-2003/test-A/in.tsv', delimiter='\t', header=None)
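
# Peek at the inputs: each in.tsv row holds a single whitespace-tokenized
# sentence in column 0, which is why column 0 is used throughout below.
print(in_data_dev0.head(3))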

# Initialize the NER pipeline. The model is pinned to the checkpoint the
# pipeline would otherwise silently default to, and the deprecated
# grouped_entities=True flag is replaced by its documented equivalent,
# aggregation_strategy="simple" (same behaviour)
ner_pipeline = pipeline(
    "ner",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",
    aggregation_strategy="simple",
)
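
# Quick sanity check (illustrative; exact scores depend on the model): the
# pipeline returns one dict per grouped entity, and the 'entity_group',
# 'start' and 'end' fields are what the mapping function below relies on, e.g.
# [{'entity_group': 'PER', 'score': 0.99, 'word': 'Wolfgang', 'start': 11, 'end': 19}, ...]
print(ner_pipeline("My name is Wolfgang and I live in Berlin"))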

# Extract the sentences from the in.tsv files and run NER on them
sentences_dev0 = in_data_dev0[0].tolist()
sentences_testA = in_data_testA[0].tolist()

ner_results_dev = ner_pipeline(sentences_dev0)
ner_results_testA = ner_pipeline(sentences_testA)

# Map the character-offset NER results onto per-word B-XXX / I-XXX / O labels
def map_ner_results(ner_results, sentences):
    ner_labels = []
    for sentence, entities in zip(sentences, ner_results):
        words = sentence.split()
        labels = ['O'] * len(words)
        for entity in entities:
            start_idx = entity['start']
            end_idx = entity['end']
            entity_label = entity['entity_group']
            # Convert character offsets to word indices by counting the
            # whitespace-separated words before and inside the entity span
            # (this assumes entity spans align with word boundaries)
            entity_words = sentence[start_idx:end_idx].split()
            start_word_idx = len(sentence[:start_idx].split())
            end_word_idx = start_word_idx + len(entity_words)
            # Tag the first word B-XXX and any following words I-XXX,
            # skipping spans that fall outside the word list
            if start_word_idx < len(labels) and end_word_idx <= len(labels):
                labels[start_word_idx] = f'B-{entity_label}'
                for i in range(start_word_idx + 1, end_word_idx):
                    labels[i] = f'I-{entity_label}'
        ner_labels.append(labels)
    return ner_labels
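
# Illustrative example (the labels shown are what one would roughly expect
# from a CoNLL-2003 checkpoint; the exact output depends on the model):
example = "George Washington went to Washington"
print(map_ner_results([ner_pipeline(example)], [example]))
# e.g. [['B-PER', 'I-PER', 'O', 'O', 'B-LOC']]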

# Map the NER results onto per-word labels
predicted_labels_dev0 = map_ner_results(ner_results_dev, sentences_dev0)
predicted_labels_testA = map_ner_results(ner_results_testA, sentences_testA)

# Convert the label lists to space-separated strings
predicted_strings_dev0 = [' '.join(labels) for labels in predicted_labels_dev0]
predicted_strings_testA = [' '.join(labels) for labels in predicted_labels_testA]

expected_strings = expected_data[0].tolist()

# Write the predictions to out.tsv
with open('en-ner-conll-2003/dev-0/out.tsv', 'w') as f:
    for line in predicted_strings_dev0:
        f.write(line + '\n')

with open('en-ner-conll-2003/test-A/out.tsv', 'w') as f:
    for line in predicted_strings_testA:
        f.write(line + '\n')
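
# Sanity check (optional): each output line must carry exactly one label per
# whitespace token of the corresponding input sentence
for sent, labels in zip(sentences_dev0, predicted_labels_dev0):
    assert len(sent.split()) == len(labels), "label/token count mismatch"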


# Evaluate: token-level accuracy against the expected dev-0 labels
correct = 0
total = 0
for pred, exp in zip(predicted_strings_dev0, expected_strings):
    pred_labels = pred.split()
    exp_labels = exp.split()
    for p, e in zip(pred_labels, exp_labels):
        if p == e:
            correct += 1
        total += 1

accuracy = correct / total
print(f"Accuracy for dev0: {accuracy:.2%}")
Accuracy for dev0: 94.88%
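
# Token-level accuracy overstates quality when most labels are 'O'; the
# standard CoNLL-2003 metric is entity-level F1. A minimal sketch, assuming
# the seqeval package is installed (pip install seqeval):
from seqeval.metrics import f1_score

expected_labels = [line.split() for line in expected_strings]
print(f"Entity-level F1 for dev0: {f1_score(expected_labels, predicted_labels_dev0):.2%}")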