chat-restaruacja/nlu_train.py

42 lines
1.8 KiB
Python
Raw Normal View History

2024-05-10 01:25:13 +02:00
from conllu import parse_incr
from flair.data import Corpus
from flair.embeddings import StackedEmbeddings
from flair.embeddings import WordEmbeddings
from flair.embeddings import CharacterEmbeddings
from flair.embeddings import FlairEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer
2024-05-22 23:45:33 +02:00
from src.utils.nlu_utils import conllu2flair, nolabel2o
2024-05-10 01:25:13 +02:00
import torch
# Applied only when a CUDA device is available: turn the cuDNN backend off,
# disable its benchmark mode and request deterministic behaviour.
if torch.cuda.is_available():
    for _cudnn_flag, _cudnn_value in (
        ("enabled", False),
        ("benchmark", False),
        ("deterministic", True),
    ):
        setattr(torch.backends.cudnn, _cudnn_flag, _cudnn_value)
def _read_conllu(path, field_parsers):
    """Parse the CoNLL-U file at *path* into a list of sentences."""
    with open(path, encoding='utf-8') as f:
        return list(parse_incr(f, fields=['id', 'form', 'frame', 'slot'], field_parsers=field_parsers))


def train_model(label_type, field_parsers=None):
    """Train a Flair sequence tagger for *label_type* on the dialog corpus.

    Reads ``data/train_dialog.conllu`` and ``data/test_dialog_46.conllu``
    (columns: id, form, frame, slot), converts both with ``conllu2flair``
    and trains a CRF ``SequenceTagger`` over stacked Polish embeddings.

    Args:
        label_type: Name of the label to learn; the ``__main__`` section of
            this script calls it with "frame" and "slot". Also used to build
            the ``f'{label_type}-model'`` path handed to the trainer.
        field_parsers: Optional mapping of column name to parser callable,
            forwarded to ``conllu.parse_incr``. ``None`` (the default) means
            no custom parsers.

    Returns:
        None.
    """
    # A None sentinel replaces the former mutable ``{}`` default so that no
    # dict object is shared between calls.
    if field_parsers is None:
        field_parsers = {}

    trainset = _read_conllu('data/train_dialog.conllu', field_parsers)
    testset = _read_conllu('data/test_dialog_46.conllu', field_parsers)

    # No dev split is passed explicitly; only train and test are supplied.
    corpus = Corpus(train=conllu2flair(trainset, label_type), test=conllu2flair(testset, label_type))
    label_dictionary = corpus.make_label_dictionary(label_type=label_type)

    # Word-level, contextual string (forward + backward) and character-level
    # embeddings are concatenated into a single representation.
    embedding_types = [
        WordEmbeddings('pl'),
        FlairEmbeddings('pl-forward'),
        FlairEmbeddings('pl-backward'),
        CharacterEmbeddings(),
    ]
    embeddings = StackedEmbeddings(embeddings=embedding_types)

    tagger = SequenceTagger(
        hidden_size=256,
        embeddings=embeddings,
        tag_dictionary=label_dictionary,
        tag_type=label_type,
        use_crf=True,
        tag_format="BIO",
    )

    frame_trainer = ModelTrainer(tagger, corpus)
    frame_trainer.train(
        f'{label_type}-model',
        learning_rate=0.1,
        mini_batch_size=16,
        max_epochs=75,
        train_with_dev=False,
    )
if __name__ == '__main__':
    # Training jobs run in order: the frame tagger first, then the slot
    # tagger, whose 'slot' column is parsed with nolabel2o.
    training_jobs = (
        ("frame", {}),
        ('slot', {'field_parsers': {'slot': nolabel2o}}),
    )
    for job_label, job_options in training_jobs:
        train_model(job_label, **job_options)