67 lines
1.6 KiB
Python
67 lines
1.6 KiB
Python
from flair.data import Sentence
|
|
from flair.models import SequenceTagger
|
|
from tqdm import tqdm
|
|
import sys
|
|
|
|
# --------------------- VARIABLES ---------------------- #
|
|
|
|
|
|
|
|
# ----------------- DATA PREPROCESSING ----------------- #
|
|
|
|
# def convert_to_BIO_format(directory):
|
|
# """
|
|
# Read data in BIO format.
|
|
# """
|
|
# _in = open(directory + "/in.tsv", "r").read().split(" ")
|
|
# _expected = open(directory + "/expected.tsv", "r").read().split(" ")
|
|
|
|
# _lines = list(map(lambda x: f"{x[0]} {x[1]}", zip(_in, _expected)))
|
|
# _lines = list(map(lambda x: "" if x.startswith("</S>") else x, _lines))
|
|
|
|
# open(directory + "/in_bio.tsv", "w").write("\n".join(_lines))
|
|
|
|
# convert_to_BIO_format("dev-0")
|
|
|
|
# ---------------------- GENERATE PREDICTIONS ---------------------- #

# Token that separates sentences inside one line of in.tsv; it is labelled "O".
SENTENCE_SEPARATOR = "</S>"


def bio_labels(sentence):
    """Convert the NER spans predicted on a flair ``Sentence`` into BIO labels.

    Returns a list with one label per token of ``sentence``: ``"O"`` outside
    any span, ``"B-<tag>"`` on the first token of a span and ``"I-<tag>"`` on
    the remaining tokens of that span.
    """
    labels = ["O"] * len(sentence)
    for span in sentence.get_spans("ner"):
        for position, token in enumerate(span):
            prefix = "B" if position == 0 else "I"
            # token.idx is 1-based, hence the -1 when indexing the label list.
            labels[token.idx - 1] = f"{prefix}-{span.tag}"
    return labels


def tag_line(line, predict):
    """Return space-separated BIO labels for every token of one input line.

    The line is split on single spaces (empty pieces dropped), which is
    intended to match flair's ``use_tokenizer=False`` tokenization. Runs of
    tokens between ``</S>`` separators are joined into a sentence and passed
    to ``predict``, which must return one label per token. Every ``</S>``
    token itself is labelled ``"O"``.

    The result always has exactly one label per input token. It therefore
    stays aligned with ``in.tsv`` even when a line does not end with
    ``</S>``, contains consecutive separators, or is empty.

    Args:
        line: One line of ``in.tsv`` without its trailing newline.
        predict: Callable mapping a space-separated sentence to a list of labels.

    Raises:
        ValueError: If ``predict`` returns a different number of labels than
            the tokens it was given; the output would otherwise be misaligned.
    """
    labels = []
    pending = []

    def flush():
        # Send the accumulated sentence (if any) to the model.
        if not pending:
            return
        predicted = predict(" ".join(pending))
        if len(predicted) != len(pending):
            raise ValueError(
                f"expected {len(pending)} labels, got {len(predicted)} "
                f"for sentence {' '.join(pending)!r}"
            )
        labels.extend(predicted)
        pending.clear()

    for token in line.split(" "):
        if not token:
            continue
        if token == SENTENCE_SEPARATOR:
            flush()
            labels.append("O")
        else:
            pending.append(token)
    flush()
    return " ".join(labels)


def main(directory):
    """Tag ``<directory>/in.tsv`` with flair's ``ner-english`` model.

    Each input line is labelled with ``tag_line`` and written as one line of
    ``<directory>/out.tsv``.
    """
    tagger = SequenceTagger.load("flair/ner-english")

    def predict(text):
        sentence = Sentence(text, use_tokenizer=False)
        tagger.predict(sentence)
        return bio_labels(sentence)

    # Iterating the file (instead of split("\n")[:-1]) keeps the last line
    # even when the file does not end with a newline.
    with open(f"{directory}/in.tsv", "r", encoding="utf-8") as in_file:
        lines = [line.rstrip("\n") for line in in_file]

    results = [tag_line(line, predict) for line in tqdm(lines)]

    with open(f"{directory}/out.tsv", "w", encoding="utf-8") as out_file:
        out_file.writelines(f"{result}\n" for result in results)


if __name__ == "__main__":
    if len(sys.argv) != 2:
        sys.exit(f"usage: {sys.argv[0]} DIRECTORY")
    DIRECTORY = sys.argv[1]
    main(DIRECTORY)
|