5.8 KiB
5.8 KiB
import bisect
import csv
import itertools
import re

import pandas as pd
from transformers import pipeline
# Load the data.  Every file is read with CSV quoting disabled: the CoNLL text
# contains literal '"' tokens, which pandas' default quoting would interpret as
# quote characters (possibly merging several input lines into one row and
# misaligning the output).  Default NA parsing is disabled as well so that a
# line such as "NA" or "null" stays a string instead of becoming NaN.
TSV_OPTIONS = dict(delimiter='\t', header=None, quoting=csv.QUOTE_NONE, keep_default_na=False)
train_data = pd.read_csv('en-ner-conll-2003/train/train.tsv.xz', compression='xz', **TSV_OPTIONS)
in_data_dev0 = pd.read_csv('en-ner-conll-2003/dev-0/in.tsv', **TSV_OPTIONS)
expected_data = pd.read_csv('en-ner-conll-2003/dev-0/expected.tsv', **TSV_OPTIONS)
in_data_testA = pd.read_csv('en-ner-conll-2003/test-A/in.tsv', **TSV_OPTIONS)

# NER pipeline, pinned to the model and revision that the task default resolved
# to, so results are reproducible.  aggregation_strategy="simple" is the
# non-deprecated equivalent of grouped_entities=True.
ner_pipeline = pipeline(
    "ner",
    model="dbmdz/bert-large-cased-finetuned-conll03-english",
    revision="f2482bf",
    aggregation_strategy="simple",
)
# NOTE(review): BERT accepts at most 512 tokens; confirm how the pipeline
# handles longer input lines (words beyond the limit may end up tagged 'O').

# Run NER over every line of the dev-0 and test-A inputs.
sentences_dev0 = in_data_dev0[0].tolist()
senteces_testA = in_data_testA[0].tolist()  # (sic) name kept: referenced later in the script
ner_results_dev = ner_pipeline(sentences_dev0)
ner_results_testA = ner_pipeline(senteces_testA)
# Map the pipeline's character-span entities onto per-word IOB2 labels
# (B-XXX, I-XXX, O), one label per whitespace-separated word.
def map_ner_results(ner_results, sentences):
    """Convert grouped NER pipeline output to one IOB2 label list per sentence.

    Args:
        ner_results: For each sentence, a list of entity dicts carrying the
            keys ``start``/``end`` (character offsets into the sentence) and
            ``entity_group`` (the label without a B-/I- prefix).
        sentences: Input sentences, paired positionally with ``ner_results``.

    Returns:
        A list with one list of labels per sentence; each inner list has
        exactly ``len(sentence.split())`` entries.

    A word is tagged when its character span overlaps the entity span, so
    entities whose offsets start or end inside a word (sub-word pieces) still
    land on the containing word rather than on its neighbour.
    """
    ner_labels = []
    for sentence, entities in zip(sentences, ner_results):
        # Character spans of the words.  `\S+` splits on the same whitespace
        # as str.split(), so the label count equals the token count.
        spans = [m.span() for m in re.finditer(r'\S+', sentence)]
        word_starts = [start for start, _ in spans]
        word_ends = [end for _, end in spans]
        labels = ['O'] * len(spans)
        for entity in entities:
            start_idx = entity['start']
            end_idx = entity['end']
            if end_idx <= start_idx:
                continue  # empty span: nothing to tag
            entity_label = entity['entity_group']
            # First word that ends after the entity starts ...
            first = bisect.bisect_right(word_ends, start_idx)
            # ... up to (excluding) the first word starting at/after its end.
            last = bisect.bisect_left(word_starts, end_idx)
            if first >= last:
                continue  # span covers only whitespace
            labels[first] = f'B-{entity_label}'
            for i in range(first + 1, last):
                labels[i] = f'I-{entity_label}'
        ner_labels.append(labels)
    return ner_labels
# Convert the raw pipeline output of both splits into per-word IOB2 labels.
predicted_labels_dev0 = map_ner_results(ner_results_dev, sentences_dev0)
predicted_labels_testA = map_ner_results(ner_results_testA, senteces_testA)

# Serialise each sentence's labels as a single space-separated line.
predicted_strings_dev0 = list(map(' '.join, predicted_labels_dev0))
predicted_strings_testA = list(map(' '.join, predicted_labels_testA))

# Reference dev-0 labels: one space-separated line per input sentence.
reference_column = expected_data[0]
expected_strings = reference_column.tolist()
# Write the predictions to out.tsv, one space-separated label line per input
# line.  The explicit encoding and newline='\n' make the files byte-identical
# on every platform (Windows text mode would otherwise emit '\r\n' endings and
# use the locale encoding).
for out_path, rows in (
    ('en-ner-conll-2003/dev-0/out.tsv', predicted_strings_dev0),
    ('en-ner-conll-2003/test-A/out.tsv', predicted_strings_testA),
):
    with open(out_path, 'w', encoding='utf-8', newline='\n') as f:
        f.writelines(row + '\n' for row in rows)
# Token-level accuracy on dev-0.  Lines are paired with zip_longest, and each
# line contributes max(len(pred), len(exp)) tokens, so missing or extra
# predictions count as errors instead of being silently dropped by zip().
correct = 0
total = 0
for pred, exp in itertools.zip_longest(predicted_strings_dev0, expected_strings, fillvalue=''):
    pred_labels = pred.split()
    exp_labels = exp.split()
    correct += sum(p == e for p, e in zip(pred_labels, exp_labels))
    total += max(len(pred_labels), len(exp_labels))
# Avoid ZeroDivisionError when there is nothing to compare.
accuracy = correct / total if total else 0.0
print(f"Accuracy for dev0: {accuracy:.2%}")
No model was supplied, defaulted to dbmdz/bert-large-cased-finetuned-conll03-english and revision f2482bf (https://huggingface.co/dbmdz/bert-large-cased-finetuned-conll03-english). Using a pipeline without specifying a model name and revision in production is not recommended. c:\Users\walcz\Desktop\studia\uczenie\RNN\en-ner-conll-2003\myenv\Lib\site-packages\huggingface_hub\file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`. warnings.warn( Some weights of the model checkpoint at dbmdz/bert-large-cased-finetuned-conll03-english were not used when initializing BertForTokenClassification: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight'] - This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). - This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). c:\Users\walcz\Desktop\studia\uczenie\RNN\en-ner-conll-2003\myenv\Lib\site-packages\transformers\pipelines\token_classification.py:168: UserWarning: `grouped_entities` is deprecated and will be removed in version v5.0.0, defaulted to `aggregation_strategy="AggregationStrategy.SIMPLE"` instead. warnings.warn(
Accuracy for dev0: 94.88%