chat-restaruacja/nlu_train.py
2024-05-22 23:45:33 +02:00

42 lines
1.8 KiB
Python

from conllu import parse_incr
from flair.data import Corpus
from flair.embeddings import StackedEmbeddings
from flair.embeddings import WordEmbeddings
from flair.embeddings import CharacterEmbeddings
from flair.embeddings import FlairEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer
from src.utils.nlu_utils import conllu2flair, nolabel2o
import torch
if torch.cuda.is_available():
    # The cuDNN switches are only adjusted when a CUDA device is present;
    # on CPU-only machines the library defaults are left untouched.
    cudnn_backend = torch.backends.cudnn
    cudnn_backend.enabled = False        # do not use cuDNN at all
    cudnn_backend.benchmark = False      # keep benchmark mode off
    cudnn_backend.deterministic = True   # ask for deterministic algorithms
def _read_conllu(path, field_parsers):
    """Parse the 4-column CoNLL-U style file at *path* into a list of sentences."""
    with open(path, encoding='utf-8') as f:
        return list(parse_incr(f, fields=['id', 'form', 'frame', 'slot'], field_parsers=field_parsers))


def train_model(label_type, field_parsers=None):
    """Train a flair ``SequenceTagger`` for one annotation column.

    The training and test sentences are read from
    ``data/train_dialog.conllu`` and ``data/test_dialog_46.conllu``, converted
    with ``conllu2flair`` and used to fit a BIO-tagged CRF sequence tagger.
    The trainer is given ``'<label_type>-model'`` as its output location.

    Args:
        label_type: Name of the label column to learn; the script calls this
            with ``"frame"`` and ``"slot"``. It is also used as the tagger's
            ``tag_type`` and as the prefix of the output location.
        field_parsers: Optional mapping of column name to parser callable,
            forwarded to ``conllu.parse_incr``. ``None`` (the default) means
            no custom parsers.

    Returns:
        None. The function is called for its training side effects.
    """
    # A fresh dict per call: the previous ``field_parsers={}`` default was a
    # single object shared by every call that relied on the default.
    parsers = {} if field_parsers is None else field_parsers

    trainset = _read_conllu('data/train_dialog.conllu', parsers)
    testset = _read_conllu('data/test_dialog_46.conllu', parsers)

    # No dev split is passed explicitly; only train and test are provided.
    corpus = Corpus(train=conllu2flair(trainset, label_type), test=conllu2flair(testset, label_type))
    label_dictionary = corpus.make_label_dictionary(label_type=label_type)

    # Word-level, contextual string (forward + backward) and character-level
    # embeddings are concatenated into a single representation.
    embedding_types = [
        WordEmbeddings('pl'),
        FlairEmbeddings('pl-forward'),
        FlairEmbeddings('pl-backward'),
        CharacterEmbeddings(),
    ]
    embeddings = StackedEmbeddings(embeddings=embedding_types)

    tagger = SequenceTagger(
        hidden_size=256,
        embeddings=embeddings,
        tag_dictionary=label_dictionary,
        tag_type=label_type,
        use_crf=True,
        tag_format="BIO",
    )

    frame_trainer = ModelTrainer(tagger, corpus)
    frame_trainer.train(
        f'{label_type}-model',
        learning_rate=0.1,
        mini_batch_size=16,
        max_epochs=75,
        train_with_dev=False,
    )
if __name__ == '__main__':
    # Train one tagger per annotation column, in this order. The "frame" run
    # relies on train_model's default parsers; the "slot" run maps the slot
    # column through nolabel2o.
    training_runs = (
        ("frame", {}),
        ('slot', {'field_parsers': {'slot': nolabel2o}}),
    )
    for column, extra_kwargs in training_runs:
        train_model(column, **extra_kwargs)