transformer_pipeline/en-ner-conll-2003/skrypcik_v2.py

87 lines
3.2 KiB
Python
Raw Normal View History

2024-06-03 12:26:40 +02:00
import pandas as pd
# Entity types whose orphan "I-" tags are repaired by _correct_tags.
_ENTITY_TYPES = frozenset({"ORG", "PER", "LOC", "MISC"})


def _correct_tags(tags):
    """Return a repaired copy of one line's list of BIO tags.

    A tag ``I-X`` (X in ``_ENTITY_TYPES``) is treated as orphaned when the tag
    before it (after correction) is neither ``B-X`` nor ``I-X``. An orphan
    becomes:

    1. ``B-X`` if the next *original* tag is also ``I-X`` (start of a new
       multi-token entity);
    2. otherwise ``I-Y`` if the previous tag is ``B-Y`` for another known
       entity type Y (treated as the continuation of that entity);
    3. otherwise ``B-X``.

    Every other tag is copied unchanged. The sequence is treated as if it were
    preceded and followed by ``"O"``.

    Args:
        tags: list of tag strings for one line.

    Returns:
        A new list of the same length with the repaired tags.
    """
    corrected = []
    previous = "O"
    # Pair each tag with its original (uncorrected) successor; the last tag
    # is paired with an implicit "O".
    for tag, following in zip(tags, tags[1:] + ["O"]):
        new_tag = tag
        entity = tag[2:]
        if (
            tag.startswith("I-")
            and entity in _ENTITY_TYPES
            and previous not in ("B-" + entity, "I-" + entity)
        ):
            previous_entity = previous[2:]
            if following == tag:
                new_tag = "B-" + entity
            elif previous.startswith("B-") and previous_entity in _ENTITY_TYPES:
                new_tag = "I-" + previous_entity
            else:
                new_tag = "B-" + entity
        corrected.append(new_tag)
        # The *corrected* tag is what the next position is compared against.
        previous = new_tag
    return corrected


def correct_labels(input_file, output_file):
    """Repair orphan ``I-`` tags in a tag file, one line at a time.

    Each line of *input_file* is split on single spaces into BIO tags. The
    tags are repaired with ``_correct_tags`` and re-joined with single spaces.
    The result is written as the corresponding line of *output_file*.

    Both files are handled as UTF-8 text, line by line, rather than through
    ``pandas.read_csv``/``to_csv``. This guarantees exactly one output line
    per input line: ``read_csv`` silently drops blank lines by default, which
    would break the line-for-line correspondence with the input.

    Args:
        input_file: path of the file to read.
        output_file: path of the file to (over)write.
    """
    # Read everything first so input_file == output_file is also safe.
    with open(input_file, encoding="utf-8") as src:
        lines = [line.rstrip("\n") for line in src]
    with open(output_file, "w", encoding="utf-8") as dst:
        for line in lines:
            # split(" ")/join(" ") (not split()) keeps the empty fields produced
            # by repeated spaces, so spacing round-trips exactly.
            dst.write(" ".join(_correct_tags(line.split(" "))) + "\n")
# Post-process both evaluation splits: each split's out.tsv is rewritten,
# with repaired tags, into out_v2.tsv in the same directory.
for input_file, output_file in (
    ("dev-0/out.tsv", "dev-0/out_v2.tsv"),
    ("test-A/out.tsv", "test-A/out_v2.tsv"),
):
    correct_labels(input_file, output_file)