2024-05-10 01:25:13 +02:00
|
|
|
from conllu import parse_incr
|
|
|
|
from flair.data import Corpus
|
|
|
|
from flair.embeddings import StackedEmbeddings
|
|
|
|
from flair.embeddings import WordEmbeddings
|
|
|
|
from flair.embeddings import CharacterEmbeddings
|
|
|
|
from flair.embeddings import FlairEmbeddings
|
|
|
|
from flair.models import SequenceTagger
|
|
|
|
from flair.trainers import ModelTrainer
|
2024-05-22 23:45:33 +02:00
|
|
|
from src.utils.nlu_utils import conllu2flair, nolabel2o
|
2024-05-10 01:25:13 +02:00
|
|
|
|
|
|
|
import torch
|
|
|
|
if torch.cuda.is_available():
    # Only touch the cuDNN switches when a CUDA device is present.
    # NOTE(review): presumably these are set for reproducible training
    # runs -- confirm with the author.
    # Ask cuDNN to pick deterministic algorithms.
    torch.backends.cudnn.deterministic = True
    # Turn off the cuDNN benchmark-based algorithm selection.
    torch.backends.cudnn.benchmark = False
    # Disable cuDNN altogether.
    torch.backends.cudnn.enabled = False
|
|
|
|
|
|
|
|
def _read_conllu(path, field_parsers):
    """Parse the CoNLL-U file at *path* and return its sentences as a list."""
    with open(path, encoding='utf-8') as f:
        return list(parse_incr(f, fields=['id', 'form', 'frame', 'slot'], field_parsers=field_parsers))


def train_model(label_type, field_parsers=None):
    """Train a Flair sequence tagger for ``label_type`` on the dialog corpus.

    Reads ``data/train_dialog.conllu`` and ``data/test_dialog_46.conllu``,
    builds a Flair ``Corpus`` from them via ``conllu2flair``, stacks Polish
    word, Flair (forward/backward) and character embeddings, and trains a
    CRF ``SequenceTagger``. ``f'{label_type}-model'`` is passed to the
    trainer as its first argument (its output path per the Flair docs).

    Args:
        label_type: Name of the label layer to learn; it is forwarded to
            ``conllu2flair``, the label dictionary, and the tagger.
        field_parsers: Optional mapping of column name to parser callable,
            forwarded to ``conllu.parse_incr``. ``None`` (the default) is
            treated as an empty mapping.

    Returns:
        None.
    """
    # A ``None`` sentinel replaces the former ``{}`` default so no dict
    # object is shared between calls; parse_incr still receives a dict.
    if field_parsers is None:
        field_parsers = {}

    trainset = _read_conllu('data/train_dialog.conllu', field_parsers)
    testset = _read_conllu('data/test_dialog_46.conllu', field_parsers)

    corpus = Corpus(train=conllu2flair(trainset, label_type), test=conllu2flair(testset, label_type))
    label_dictionary = corpus.make_label_dictionary(label_type=label_type)

    embedding_types = [
        WordEmbeddings('pl'),
        FlairEmbeddings('pl-forward'),
        FlairEmbeddings('pl-backward'),
        CharacterEmbeddings(),
    ]
    embeddings = StackedEmbeddings(embeddings=embedding_types)

    tagger = SequenceTagger(
        hidden_size=256,
        embeddings=embeddings,
        tag_dictionary=label_dictionary,
        tag_type=label_type,
        use_crf=True,
        tag_format="BIO",
    )

    frame_trainer = ModelTrainer(tagger, corpus)
    frame_trainer.train(
        f'{label_type}-model',
        learning_rate=0.1,
        mini_batch_size=16,
        max_epochs=75,
        train_with_dev=False,
    )
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Train the frame tagger with the default field parsers.
    train_model(label_type="frame")
    # Train the slot tagger, parsing the 'slot' column with nolabel2o.
    train_model(label_type='slot', field_parsers={'slot': nolabel2o})
|