WIP
Co-authored-by: Alexxiia <Alexxiia@users.noreply.github.com>
This commit is contained in:
parent
6c4ccfba4b
commit
04464285bd
68
train.py
Normal file
68
train.py
Normal file
@ -0,0 +1,68 @@
|
||||
from flair.data import Sentence
|
||||
from flair.models import SequenceTagger
|
||||
from tqdm import tqdm
|
||||
|
||||
# --------------------- VARIABLES ---------------------- #
|
||||
|
||||
|
||||
|
||||
# ----------------- DATA PREPROCESSING ----------------- #
|
||||
|
||||
# def convert_to_BIO_format(directory):
|
||||
# """
|
||||
# Read data in BIO format.
|
||||
# """
|
||||
# _in = open(directory + "/in.tsv", "r").read().split(" ")
|
||||
# _expected = open(directory + "/expected.tsv", "r").read().split(" ")
|
||||
|
||||
# _lines = list(map(lambda x: f"{x[0]} {x[1]}", zip(_in, _expected)))
|
||||
# _lines = list(map(lambda x: "" if x.startswith("</S>") else x, _lines))
|
||||
|
||||
# open(directory + "/in_bio.tsv", "w").write("\n".join(_lines))
|
||||
|
||||
# convert_to_BIO_format("dev-0")
|
||||
|
||||
# ---------------------- GENERATE PREDICTIONS ---------------------- #
|
||||
DIRECTORY = "dev-0"
|
||||
|
||||
# load tagger
|
||||
tagger = SequenceTagger.load("flair/ner-english")
|
||||
batches = open(f"{DIRECTORY}/in.tsv", "r").read().split("\n")[:-1]
|
||||
results = []
|
||||
|
||||
with tqdm(total=len(batches)) as pbar:
|
||||
for batch in batches:
|
||||
batch_result = []
|
||||
|
||||
for sentence in batch.split("</S>"):
|
||||
raw_sentence = sentence
|
||||
sentence = sentence.strip()
|
||||
|
||||
if sentence == "":
|
||||
continue
|
||||
|
||||
sentence = Sentence(sentence, use_tokenizer=False)
|
||||
|
||||
tagger.predict(sentence)
|
||||
|
||||
out = ["O"] * len(sentence)
|
||||
|
||||
for entity in sentence.get_spans('ner'):
|
||||
|
||||
is_first = True
|
||||
for token in entity:
|
||||
if is_first:
|
||||
out[token.idx - 1] = f"B-{entity.tag}"
|
||||
else:
|
||||
out[token.idx - 1] = f"I-{entity.tag}"
|
||||
|
||||
is_first = False
|
||||
batch_result.append(" ".join(out))
|
||||
|
||||
results.append(" O ".join(batch_result) + " O")
|
||||
|
||||
pbar.update(1)
|
||||
|
||||
open(f"{DIRECTORY}/out.tsv", "w").write("\n".join(results))
|
||||
|
||||
# 456
|
Loading…
Reference in New Issue
Block a user