uczenie_glebokie/transformer/transformer.py
Kacper Kalinowski 22add9d0af transformer
2024-06-09 23:10:50 +02:00

110 lines
4.0 KiB
Python

import pandas as pd
import numpy as np
from transformers import pipeline
from sklearn.metrics import classification_report
import sys
# Force UTF-8 on stdout so the classification report below prints cleanly
# even when the console's default encoding is not UTF-8 (e.g. Windows code
# pages); requires Python 3.7+ for reconfigure().
sys.stdout.reconfigure(encoding='utf-8')
# Load data
def load_data(file_path):
    """Read a UTF-8 text file and return its lines, stripped of surrounding whitespace."""
    with open(file_path, 'r', encoding='utf-8') as handle:
        return [row.strip() for row in handle]
# Training data: TSV of <label, sentence> columns (per `names=` below).
# Dev/test inputs are one sentence per line; dev gold tags in expected.tsv.
# NOTE(review): labels appear to be space-separated tag sequences — they are
# split per-token later in preprocess_data; confirm against the data files.
train_data = pd.read_csv('./train/train.tsv', sep='\t', header=None, names=['label', 'sentence'], encoding='utf-8')
dev_sentences = load_data('./dev-0/in.tsv')
dev_labels = load_data('./dev-0/expected.tsv')
test_sentences = load_data('./test-A/in.tsv')
# Preprocess data
def preprocess_data(sentences, labels=None):
    """Whitespace-tokenize sentences; when labels are given, tokenize those too.

    Returns the token lists alone, or a (sentence_tokens, label_tokens) pair
    when `labels` is provided.
    """
    sentence_tokens = [text.split() for text in sentences]
    if labels is None:
        return sentence_tokens
    label_tokens = [tag_line.split() for tag_line in labels]
    return sentence_tokens, label_tokens
# Tokenize everything up front. NOTE(review): train_sentences/train_labels
# are not referenced again in this script — the pipeline below is used
# pre-trained, not fine-tuned; confirm the training split is intentionally
# unused.
train_sentences, train_labels = preprocess_data(train_data['sentence'].values, train_data['label'].values)
# dev_sentences/dev_labels are rebound here from raw lines to token lists.
dev_sentences, dev_labels = preprocess_data(dev_sentences, dev_labels)
test_sentences = preprocess_data(test_sentences)
# Define NER pipeline
# Pre-trained English CoNLL-03 model. aggregation_strategy="simple" groups
# word-piece predictions into whole-entity spans, each carrying an
# 'entity_group' plus character 'start'/'end' offsets (consumed below).
ner_pipeline = pipeline(task="ner", model="dbmdz/bert-large-cased-finetuned-conll03-english", aggregation_strategy="simple")
# Predict NER tags
def predict_ner_tags(sentences):
    """Run the NER pipeline over tokenized sentences and return IOB tags.

    Each sentence is a list of tokens; it is re-joined with single spaces so
    the pipeline's character offsets can be mapped back onto tokens. Returns
    one list of tags ('O' / 'B-XXX' / 'I-XXX') per sentence, aligned with
    the input tokens.
    """
    ner_tags = []
    for tokens in sentences:
        text = " ".join(tokens)
        tags = ['O'] * len(tokens)
        # Compute each token's (start, end) character span in `text` ONCE
        # per sentence. The original rescanned the tokens for every entity,
        # doing O(entities * tokens) work for the same offsets.
        spans = []
        offset = 0
        for token in tokens:
            spans.append((offset, offset + len(token)))
            offset += len(token) + 1  # +1 for the joining space
        for entity in ner_pipeline(text):
            entity_type = entity['entity_group']
            ent_start, ent_end = entity['start'], entity['end']
            # Token spans are disjoint, so at most one token matches each
            # boundary condition — identical to the original first-match scan.
            first = next((i for i, (s, e) in enumerate(spans) if s <= ent_start < e), None)
            last = next((i for i, (s, e) in enumerate(spans) if s < ent_end <= e), None)
            if first is None or last is None:
                # Entity boundary fell between tokens (e.g. on a space);
                # skip it, as the original did.
                continue
            tags[first] = f'B-{entity_type}'
            for i in range(first + 1, last + 1):
                tags[i] = f'I-{entity_type}'
        ner_tags.append(tags)
    return ner_tags
# Correct IOB labels function
def correct_iob_labels(predictions):
    """Repair IOB sequences: any I- tag that does not continue a same-type
    entity (follows 'O' or a different entity type) is rewritten as B-."""
    fixed = []
    for sentence_tags in predictions:
        repaired = []
        previous = 'O'
        for tag in sentence_tags:
            orphan_inside = tag.startswith('I-') and (previous == 'O' or previous[2:] != tag[2:])
            current = 'B-' + tag[2:] if orphan_inside else tag
            repaired.append(current)
            previous = current
        fixed.append(repaired)
    return fixed
# Get and correct predictions for dev and test sets, normalizing the raw
# pipeline output into valid IOB sequences.
dev_pred_tags = correct_iob_labels(predict_ner_tags(dev_sentences))
test_pred_tags = correct_iob_labels(predict_ner_tags(test_sentences))
# Flatten the lists for evaluation
def flatten_labels(labels):
    """Concatenate a list of per-sentence tag lists into one flat list."""
    flat = []
    for sentence_tags in labels:
        flat.extend(sentence_tags)
    return flat
dev_true_tags = flatten_labels(dev_labels)
dev_pred_tags_flat = flatten_labels(dev_pred_tags)
# Build ONE sorted, deterministic label list and reuse it for both `labels`
# and `target_names`. The original passed two independent `list(set(...))`
# calls; the two lists must correspond element-wise for the report rows to
# be named correctly, and set iteration order is not guaranteed stable
# across runs (hash randomization).
label_order = sorted(set(dev_true_tags))
print(classification_report(
    dev_true_tags,
    dev_pred_tags_flat,
    labels=label_order,
    target_names=label_order,
))
# Save dev predictions, one space-joined tag sequence per line.
dev_predictions = [' '.join(sentence_tags) for sentence_tags in dev_pred_tags]
with open('./dev-0/out.tsv', 'w', encoding='utf-8') as out_file:
    out_file.writelines("%s\n" % line for line in dev_predictions)
# Save test predictions in the same format.
test_predictions = [' '.join(sentence_tags) for sentence_tags in test_pred_tags]
with open('./test-A/out.tsv', 'w', encoding='utf-8') as out_file:
    out_file.writelines("%s\n" % line for line in test_predictions)