WIP
Co-authored-by: Alexxiia <Alexxiia@users.noreply.github.com>
parent 6c4ccfba4b
commit 04464285bd
68 train.py Normal file
@@ -0,0 +1,68 @@
from flair.data import Sentence
from flair.models import SequenceTagger
from tqdm import tqdm

# --------------------- VARIABLES ---------------------- #


# ----------------- DATA PREPROCESSING ----------------- #

# def convert_to_BIO_format(directory):
#     """
#     Read data in BIO format.
#     """
#     _in = open(directory + "/in.tsv", "r").read().split(" ")
#     _expected = open(directory + "/expected.tsv", "r").read().split(" ")
#
#     _lines = list(map(lambda x: f"{x[0]} {x[1]}", zip(_in, _expected)))
#     _lines = list(map(lambda x: "" if x.startswith("</S>") else x, _lines))
#
#     open(directory + "/in_bio.tsv", "w").write("\n".join(_lines))
#
# convert_to_BIO_format("dev-0")

# ---------------------- GENERATE PREDICTIONS ---------------------- #

DIRECTORY = "dev-0"

# load tagger
tagger = SequenceTagger.load("flair/ner-english")

# one document per line of in.tsv: whitespace-separated tokens with "</S>" marking sentence boundaries
batches = open(f"{DIRECTORY}/in.tsv", "r").read().split("\n")[:-1]
results = []

with tqdm(total=len(batches)) as pbar:
    for batch in batches:
        batch_result = []

        for sentence in batch.split("</S>"):
            raw_sentence = sentence
            sentence = sentence.strip()

            if sentence == "":
                continue

            # keep the input's whitespace tokenization so tags stay aligned with the original tokens
            sentence = Sentence(sentence, use_tokenizer=False)

            tagger.predict(sentence)

            # default every token to "O", then overwrite tokens covered by predicted entity spans
            out = ["O"] * len(sentence)

            for entity in sentence.get_spans('ner'):
                is_first = True
                for token in entity:
                    if is_first:
                        out[token.idx - 1] = f"B-{entity.tag}"
                    else:
                        out[token.idx - 1] = f"I-{entity.tag}"

                    is_first = False

            batch_result.append(" ".join(out))

        # put an "O" back for every "</S>" separator that was split out above
        results.append(" O ".join(batch_result) + " O")

        pbar.update(1)

open(f"{DIRECTORY}/out.tsv", "w").write("\n".join(results))

# 456
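
A minimal sanity check on the generated file, assuming the dev-0 layout train.py already reads (one whitespace-tokenized document per line of in.tsv, one BIO tag per token in out.tsv, and every document ending in a "</S>" marker, which the trailing " O" in the loop above relies on): comparing tag count against token count per line quickly surfaces off-by-one errors in the token.idx offsets or a missing separator tag.

DIRECTORY = "dev-0"

with open(f"{DIRECTORY}/in.tsv") as f_in, open(f"{DIRECTORY}/out.tsv") as f_out:
    for line_no, (tokens, tags) in enumerate(zip(f_in, f_out), start=1):
        # every input token (including each "</S>") should have exactly one predicted tag
        n_tokens = len(tokens.split())
        n_tags = len(tags.split())
        if n_tokens != n_tags:
            print(f"line {line_no}: {n_tokens} tokens vs {n_tags} tags")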