# en-ner-conll-2003/train.py
#
# 67 lines
# 1.6 KiB
# Python

from flair.data import Sentence
from flair.models import SequenceTagger
from tqdm import tqdm
import sys
# --------------------- VARIABLES ---------------------- #
# ----------------- DATA PREPROCESSING ----------------- #
# def convert_to_BIO_format(directory):
# """
# Convert in.tsv/expected.tsv to BIO format ("token tag" per line, blank line at each </S>) and write in_bio.tsv.
# """
# _in = open(directory + "/in.tsv", "r").read().split(" ")
# _expected = open(directory + "/expected.tsv", "r").read().split(" ")
# _lines = list(map(lambda x: f"{x[0]} {x[1]}", zip(_in, _expected)))
# _lines = list(map(lambda x: "" if x.startswith("</S>") else x, _lines))
# open(directory + "/in_bio.tsv", "w").write("\n".join(_lines))
# convert_to_BIO_format("dev-0")
# ---------------------- GENERATE PREDICTIONS ---------------------- #
DIRECTORY = sys.argv[1]
tagger = SequenceTagger.load("flair/ner-english")
batches = open(f"{DIRECTORY}/in.tsv", "r").read().split("\n")[:-1]
results = []
with tqdm(total=len(batches)) as pbar:
for batch in batches:
batch_result = []
for sentence in batch.split("</S>"):
raw_sentence = sentence
sentence = sentence.strip()
if sentence == "":
continue
sentence = Sentence(sentence, use_tokenizer=False)
tagger.predict(sentence)
out = ["O"] * len(sentence)
for entity in sentence.get_spans('ner'):
is_first = True
for token in entity:
if is_first:
out[token.idx - 1] = f"B-{entity.tag}"
else:
out[token.idx - 1] = f"I-{entity.tag}"
is_first = False
batch_result.append(" ".join(out))
results.append(" O ".join(batch_result) + " O")
pbar.update(1)
open(f"{DIRECTORY}/out.tsv", "w").write("\n".join(results))