# NER tagging script: pretrained BERT (CoNLL-03) via HuggingFace pipelines.
# Standard library.
import sys

# Third-party.
import numpy as np
import pandas as pd
from sklearn.metrics import classification_report
from transformers import pipeline

# Force UTF-8 on stdout so the classification report and any non-ASCII
# tokens print correctly even on consoles with a narrower default encoding
# (e.g. cp1252 on Windows).
sys.stdout.reconfigure(encoding='utf-8')
# Load data
def load_data(file_path):
    """Read a UTF-8 text file and return one stripped string per line.

    Args:
        file_path: Path to a file with one sentence (or one space-separated
            tag sequence) per line.

    Returns:
        list[str]: The file's lines with surrounding whitespace removed.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        # Iterating the file object yields lines lazily; strip() drops the
        # trailing newline along with any other surrounding whitespace.
        return [line.strip() for line in file]
# train.tsv holds tab-separated "<label-sequence>\t<sentence>" rows with no
# header; dev/test in.tsv hold one sentence per line, and dev expected.tsv
# holds the matching space-separated IOB tag sequences.
train_data = pd.read_csv('./train/train.tsv', sep='\t', header=None,
                         names=['label', 'sentence'], encoding='utf-8')
dev_sentences = load_data('./dev-0/in.tsv')
dev_labels = load_data('./dev-0/expected.tsv')
test_sentences = load_data('./test-A/in.tsv')
# Preprocess data
def preprocess_data(sentences, labels=None):
    """Split whitespace-delimited strings into token lists.

    Args:
        sentences: Iterable of sentence strings, e.g. "John runs".
        labels: Optional iterable of matching tag strings, e.g. "B-PER O".

    Returns:
        The tokenized sentences, or a ``(sentences, labels)`` pair of
        token lists when *labels* is provided.
    """
    tokenized_sentences = [sentence.split() for sentence in sentences]
    if labels is not None:
        tokenized_labels = [label.split() for label in labels]
        return tokenized_sentences, tokenized_labels
    return tokenized_sentences
# Tokenise every split; the training split comes from DataFrame columns,
# dev/test from the plain-text files loaded above.
train_sentences, train_labels = preprocess_data(train_data['sentence'].values,
                                                train_data['label'].values)
dev_sentences, dev_labels = preprocess_data(dev_sentences, dev_labels)
test_sentences = preprocess_data(test_sentences)
# Define NER pipeline: BERT-large fine-tuned on CoNLL-03 English.
# aggregation_strategy="simple" merges word-piece tokens into whole-entity
# spans, each reported with character offsets into the input string.
ner_pipeline = pipeline(task="ner",
                        model="dbmdz/bert-large-cased-finetuned-conll03-english",
                        aggregation_strategy="simple")
# Predict NER tags
def predict_ner_tags(sentences, ner=None):
    """Tag tokenised sentences with IOB entity labels.

    Runs the NER pipeline on each sentence (tokens re-joined with single
    spaces) and projects the returned character-level entity spans back onto
    per-token tags.

    Args:
        sentences: List of token lists, e.g. ``[["John", "runs"], ...]``.
        ner: Optional callable taking a sentence string and returning a list
            of entity dicts with ``entity_group``, ``start`` and ``end``
            keys. Defaults to the module-level ``ner_pipeline``; injectable
            for testing.

    Returns:
        list[list[str]]: One IOB tag per input token per sentence.
    """
    if ner is None:
        ner = ner_pipeline
    ner_tags = []
    for sentence in sentences:
        # Offsets returned by the pipeline index into this joined string.
        ner_result = ner(" ".join(sentence))
        tags = ['O'] * len(sentence)
        for entity in ner_result:
            entity_type = entity['entity_group']
            start_idx, end_idx = entity['start'], entity['end']
            # Walk the tokens, tracking each token's character range in the
            # joined string, to find which tokens the span [start, end)
            # covers.
            token_start_idx = None
            token_end_idx = None
            char_idx = 0
            for i, token in enumerate(sentence):
                token_start = char_idx
                token_end = char_idx + len(token)
                if token_start <= start_idx < token_end:
                    token_start_idx = i
                if token_start < end_idx <= token_end:
                    token_end_idx = i
                    break
                char_idx = token_end + 1  # +1 for the joining space
            # Only tag when both ends aligned with token boundaries;
            # otherwise the entity is silently skipped.
            if token_start_idx is not None and token_end_idx is not None:
                tags[token_start_idx] = f'B-{entity_type}'
                for i in range(token_start_idx + 1, token_end_idx + 1):
                    tags[i] = f'I-{entity_type}'
        ner_tags.append(tags)
    return ner_tags
# Correct IOB labels function
def correct_iob_labels(predictions):
    """Repair IOB tag sequences so every entity begins with a ``B-`` tag.

    An ``I-X`` that follows ``O`` or a tag of a different entity type is
    rewritten to ``B-X``; all other tags are kept unchanged. This fixes
    sequences that would otherwise be invalid IOB2.

    Args:
        predictions: List of per-sentence tag lists.

    Returns:
        list[list[str]]: The corrected per-sentence tag lists.
    """
    corrected = []
    for pred in predictions:
        corrected_sentence = []
        prev_label = 'O'
        for label in pred:
            # Compare the (possibly already corrected) previous tag's type
            # with this tag's type; a mismatch means a new entity starts here.
            if label.startswith('I-') and (prev_label == 'O' or prev_label[2:] != label[2:]):
                corrected_sentence.append('B-' + label[2:])
            else:
                corrected_sentence.append(label)
            prev_label = corrected_sentence[-1]
        corrected.append(corrected_sentence)
    return corrected
# Get and correct predictions for dev and test sets: model inference first,
# then IOB repair so downstream evaluation sees valid tag sequences.
dev_pred_tags = correct_iob_labels(predict_ner_tags(dev_sentences))
test_pred_tags = correct_iob_labels(predict_ner_tags(test_sentences))
# Flatten the lists for evaluation
def flatten_labels(labels):
    """Flatten a list of per-sentence tag lists into a single flat list.

    Args:
        labels: List of lists of tags.

    Returns:
        list: All tags concatenated in order.
    """
    return [item for sublist in labels for item in sublist]
dev_true_tags = flatten_labels(dev_labels)
dev_pred_tags_flat = flatten_labels(dev_pred_tags)

# Print the classification report. Build the label list once and sort it so
# that `labels` and `target_names` are guaranteed to be identical and
# aligned (the original built two independent `list(set(...))`s, whose
# orderings are not guaranteed to match) and the report order is
# deterministic across runs.
report_labels = sorted(set(dev_true_tags))
print(classification_report(
    dev_true_tags,
    dev_pred_tags_flat,
    labels=report_labels,
    target_names=report_labels,
))
# Save dev predictions: one space-separated tag sequence per line, aligned
# row-for-row with ./dev-0/in.tsv.
dev_predictions = [' '.join(tags) for tags in dev_pred_tags]
with open('./dev-0/out.tsv', 'w', encoding='utf-8') as f:
    for prediction in dev_predictions:
        f.write("%s\n" % prediction)

# Save test predictions in the same format, aligned with ./test-A/in.tsv.
test_predictions = [' '.join(tags) for tags in test_pred_tags]
with open('./test-A/out.tsv', 'w', encoding='utf-8') as f:
    for prediction in test_predictions:
        f.write("%s\n" % prediction)