46 lines
1.7 KiB
Python
46 lines
1.7 KiB
Python
|
from conllu import parse_incr
|
||
|
from flair.data import Corpus
|
||
|
from flair.embeddings import StackedEmbeddings
|
||
|
from flair.embeddings import WordEmbeddings
|
||
|
from flair.embeddings import CharacterEmbeddings
|
||
|
from flair.embeddings import FlairEmbeddings
|
||
|
from flair.models import SequenceTagger
|
||
|
from flair.trainers import ModelTrainer
|
||
|
from nlu_utils import conllu2flair, nolabel2o
|
||
|
|
||
|
import random
|
||
|
import torch
|
||
|
# Fix the seeds of Python's `random` module and of PyTorch so that repeated
# runs of this script start from the same random state.
random.seed(42)
torch.manual_seed(42)

if torch.cuda.is_available():
    # The CUDA generators are seeded explicitly with 0 (not 42 as above).
    torch.cuda.manual_seed(0)
    torch.cuda.manual_seed_all(0)

    # cuDNN is switched off, with benchmarking disabled and the
    # deterministic flag set.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False
|
||
|
|
||
|
|
||
|
def train_model(label_type, field_parsers=None):
    """Train a flair ``SequenceTagger`` for the given label column.

    Sentences are read from ``data/train_dialog.conllu`` with the columns
    ``id``, ``form``, ``frame`` and ``slot``, converted with ``conllu2flair``
    and used to train a tagger on stacked Polish word, Flair (forward and
    backward) and character embeddings. ``f'{label_type}-model'`` is passed
    to ``ModelTrainer.train`` as its first argument (the output location).

    NOTE(review): the corpus is built with the same sentences as both the
    train and the test split, so any test score reported during training is
    computed on data the model was trained on. This is kept unchanged from
    the original; confirm that it is intended.

    Args:
        label_type: Name of the label to learn. It is forwarded to
            ``conllu2flair``, used to build the label dictionary, set as the
            tagger's ``tag_type`` and used in the output path.
        field_parsers: Optional mapping forwarded to ``conllu.parse_incr``
            as ``field_parsers``. ``None`` (the default) is treated as an
            empty mapping.
    """
    # Build a fresh dict per call instead of sharing one mutable default
    # object between all invocations of this function.
    parsers = {} if field_parsers is None else field_parsers

    with open('data/train_dialog.conllu', encoding='utf-8') as trainfile:
        # parse_incr is consumed inside the `with` block, while the file is
        # still open; list() materialises all sentences.
        trainset = list(parse_incr(
            trainfile,
            fields=['id', 'form', 'frame', 'slot'],
            field_parsers=parsers,
        ))

    corpus = Corpus(
        train=conllu2flair(trainset, label_type),
        test=conllu2flair(trainset, label_type),
    )
    label_dictionary = corpus.make_label_dictionary(label_type=label_type)

    embedding_types = [
        WordEmbeddings('pl'),
        FlairEmbeddings('pl-forward'),
        FlairEmbeddings('pl-backward'),
        CharacterEmbeddings(),
    ]
    embeddings = StackedEmbeddings(embeddings=embedding_types)

    tagger = SequenceTagger(
        hidden_size=256,
        embeddings=embeddings,
        tag_dictionary=label_dictionary,
        tag_type=label_type,
        use_crf=True,
        tag_format="BIO",
    )

    trainer = ModelTrainer(tagger, corpus)
    trainer.train(
        f'{label_type}-model',
        learning_rate=0.1,
        mini_batch_size=32,
        max_epochs=75,
        train_with_dev=False,
    )
|
||
|
|
||
|
if __name__ == '__main__':
    # One entry per model to train: the label type plus any extra keyword
    # arguments for train_model. The slot model additionally passes
    # `nolabel2o` as the field parser for the 'slot' column.
    training_jobs = (
        ("frame", {}),
        ('slot', {'field_parsers': {'slot': nolabel2o}}),
    )
    for label, extra_kwargs in training_jobs:
        train_model(label, **extra_kwargs)
|