Co-authored-by: Alexxiia <Alexxiia@users.noreply.github.com>
This commit is contained in:
Marcin Czerniak 2023-06-21 01:52:12 +02:00
parent 6c4ccfba4b
commit 04464285bd
1 changed files with 68 additions and 0 deletions

68
train.py Normal file
View File

@ -0,0 +1,68 @@
from flair.data import Sentence
from flair.models import SequenceTagger
from tqdm import tqdm
# --------------------- VARIABLES ---------------------- #
# ----------------- DATA PREPROCESSING ----------------- #
# def convert_to_BIO_format(directory):
# """
# Read data in BIO format.
# """
# _in = open(directory + "/in.tsv", "r").read().split(" ")
# _expected = open(directory + "/expected.tsv", "r").read().split(" ")
# _lines = list(map(lambda x: f"{x[0]} {x[1]}", zip(_in, _expected)))
# _lines = list(map(lambda x: "" if x.startswith("</S>") else x, _lines))
# open(directory + "/in_bio.tsv", "w").write("\n".join(_lines))
# convert_to_BIO_format("dev-0")
# ---------------------- GENERATE PREDICTIONS ---------------------- #
DIRECTORY = "dev-0"
# load tagger
tagger = SequenceTagger.load("flair/ner-english")
batches = open(f"{DIRECTORY}/in.tsv", "r").read().split("\n")[:-1]
results = []
with tqdm(total=len(batches)) as pbar:
for batch in batches:
batch_result = []
for sentence in batch.split("</S>"):
raw_sentence = sentence
sentence = sentence.strip()
if sentence == "":
continue
sentence = Sentence(sentence, use_tokenizer=False)
tagger.predict(sentence)
out = ["O"] * len(sentence)
for entity in sentence.get_spans('ner'):
is_first = True
for token in entity:
if is_first:
out[token.idx - 1] = f"B-{entity.tag}"
else:
out[token.idx - 1] = f"I-{entity.tag}"
is_first = False
batch_result.append(" ".join(out))
results.append(" O ".join(batch_result) + " O")
pbar.update(1)
open(f"{DIRECTORY}/out.tsv", "w").write("\n".join(results))
# 456