import pandas as pd

# Entity types whose "I-" labels are repaired unless the caller overrides them.
DEFAULT_ENTITY_TYPES = ("ORG", "PER", "LOC", "MISC")


def _correct_label(label, previous, following, entity_types):
    """Return a BIO-consistent replacement for a single *label*.

    *previous* is the already-corrected label to the left ("O" at the start
    of a line); *following* is the original, uncorrected label to the right
    ("O" at the end of a line). Labels other than "I-<type>" with <type> in
    *entity_types* are returned unchanged.
    """
    if not label.startswith("I-"):
        return label
    entity = label[2:]
    if entity not in entity_types:
        return label
    if previous in (f"B-{entity}", f"I-{entity}"):
        # Already a valid continuation of the same entity.
        return label
    if following == label:
        # A span of this entity starts here: promote its first token to B-.
        return f"B-{entity}"
    if previous.startswith("B-") and previous[2:] in entity_types:
        # A lone I- directly after another entity's B- is relabelled as a
        # continuation of that entity.
        # NOTE(review): after another entity's I- (not B-) the label falls
        # through to B- below; kept as-is to preserve existing output.
        return f"I-{previous[2:]}"
    # Otherwise the label starts a new entity of its own type.
    return f"B-{entity}"


def _correct_line(line, entity_types):
    """Correct every space-separated label in *line* and rejoin them."""
    # Split on single spaces (not str.split()) so the number of labels always
    # matches the input, even with repeated or trailing spaces.
    labels = line.split(" ")
    corrected = []
    previous = "O"
    for label, following in zip(labels, labels[1:] + ["O"]):
        previous = _correct_label(label, previous, following, entity_types)
        corrected.append(previous)
    return " ".join(corrected)


def correct_labels(input_file, output_file, *, entity_types=DEFAULT_ENTITY_TYPES):
    """Repair invalid BIO transitions in a file of space-separated labels.

    Each line of *input_file* is corrected independently and written to
    *output_file*. Line count and order are preserved, including blank lines,
    so the output stays aligned with the input line for line.

    Args:
        input_file: Path of the UTF-8 label file to read.
        output_file: Path to write the corrected labels to (overwritten).
        entity_types: Entity types whose "I-" labels may be rewritten.
    """
    # Read everything first so input_file may safely equal output_file.
    with open(input_file, encoding="utf-8") as src:
        lines = [raw.rstrip("\n") for raw in src]
    with open(output_file, "w", encoding="utf-8", newline="\n") as dst:
        dst.writelines(_correct_line(line, entity_types) + "\n" for line in lines)


if __name__ == "__main__":
    for split_dir in ("dev-0", "test-A"):
        correct_labels(f"{split_dir}/out.tsv", f"{split_dir}/out_v2.tsv")