"""Generate BIO-formatted NER predictions with a pretrained Flair tagger.

Reads ``<DIRECTORY>/in.tsv`` (one batch of whitespace-separated tokens per
line, sentences delimited by ``SENTENCE_SEPARATOR``) and writes one line of
space-separated BIO tags per input line to ``<DIRECTORY>/out.tsv``.
"""

from flair.data import Sentence
from flair.models import SequenceTagger
from tqdm import tqdm

# --------------------- VARIABLES ---------------------- #

DIRECTORY = "dev-0"

# Token that separates sentences inside one input line.
# NOTE(review): the code previously called ``batch.split("")``, which raises
# ``ValueError: empty separator`` on the very first line, so the literal was
# evidently lost (it looks like it was stripped as markup). "</S>" is assumed
# here because every separator is emitted below as exactly one "O" tag --
# confirm against the dataset before relying on it.
SENTENCE_SEPARATOR = "</S>"

# ---------------------- GENERATE PREDICTIONS ---------------------- #


def _bio_tags(sentence):
    """Return one BIO tag per token of an already-predicted *sentence*.

    Tokens outside every ``ner`` span get ``"O"``; the first token of a span
    gets ``"B-<tag>"`` and the remaining ones ``"I-<tag>"``.
    """
    tags = ["O"] * len(sentence)
    for entity in sentence.get_spans("ner"):
        for position, token in enumerate(entity):
            prefix = "B" if position == 0 else "I"
            # ``token.idx`` is used as a 1-based position, hence the ``- 1``.
            tags[token.idx - 1] = f"{prefix}-{entity.tag}"
    return tags


def _tag_batch(tagger, batch):
    """Tag every sentence of one input line and return its output line."""
    sentence_tags = []
    for text in batch.split(SENTENCE_SEPARATOR):
        text = text.strip()
        if not text:
            # Nothing between two separators (or after the last one).
            continue

        # The input is already tokenised, so only split on whitespace.
        sentence = Sentence(text, use_tokenizer=False)
        tagger.predict(sentence)
        sentence_tags.append(" ".join(_bio_tags(sentence)))

    # Each separator token is itself labelled "O": one between consecutive
    # sentences and one after the last sentence.
    return " O ".join(sentence_tags) + " O"


# load tagger
tagger = SequenceTagger.load("flair/ner-english")

# Iterate the file line by line so the final record is kept even when the
# file does not end with a newline (``read().split("\n")[:-1]`` dropped it).
with open(f"{DIRECTORY}/in.tsv", "r", encoding="utf-8") as in_file:
    batches = [line.rstrip("\n") for line in in_file]

results = [_tag_batch(tagger, batch) for batch in tqdm(batches)]

with open(f"{DIRECTORY}/out.tsv", "w", encoding="utf-8") as out_file:
    out_file.write("\n".join(results))