Update 'NLU_lab_7-8/main.py'

Kacper Dudzic 2022-05-02 17:49:46 +02:00
parent 8b9d005112
commit 41530dae94


@@ -1,85 +1,85 @@
from conllu import parse_incr
from flair.data import Corpus, Sentence, Token
from flair.datasets import SentenceDataset
from flair.embeddings import StackedEmbeddings
from flair.embeddings import WordEmbeddings
from flair.embeddings import CharacterEmbeddings
from flair.embeddings import FlairEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer
import random
import torch
from tabulate import tabulate
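# Column layout of the CoNLL-U files used in this lab: token id, surface form,
# frame and slot annotations.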
fields = ['id', 'form', 'frame', 'slot']
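# The corpus marks unannotated tokens as 'NoLabel'; map that value to the standard 'O' (outside) tag.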
def nolabel2o(line, i):
    return 'O' if line[i] == 'NoLabel' else line[i]
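# Convert parsed CoNLL-U sentences into a flair SentenceDataset; when a label name
# is given (e.g. 'slot'), copy that column onto each token as a tag.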
def conllu2flair(sentences, label=None):
    fsentences = []
    for sentence in sentences:
        fsentence = Sentence()
        for token in sentence:
            ftoken = Token(token['form'])
            if label:
                ftoken.add_tag(label, token[label])
            fsentence.add_token(ftoken)
        fsentences.append(fsentence)
    return SentenceDataset(fsentences)
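# Tag a tokenized utterance with a trained model and return (word, predicted slot) pairs.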
def predict(model, sentence):
    csentence = [{'form': word} for word in sentence]
    fsentence = conllu2flair([csentence])[0]
    model.predict(fsentence)
    return [(token, ftoken.get_tag('slot').value) for token, ftoken in zip(sentence, fsentence)]
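# Read the training and test splits, mapping 'NoLabel' slot values to 'O'.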
-with open('train-pl-all.conllu', encoding='utf-8') as trainfile:
+with open('train-pl-full.conllu', encoding='utf-8') as trainfile:
    trainset = list(parse_incr(trainfile, fields=fields, field_parsers={'slot': nolabel2o}))
-with open('test-pl-all.conllu', encoding='utf-8') as testfile:
+with open('test-pl-full.conllu', encoding='utf-8') as testfile:
    testset = list(parse_incr(testfile, fields=fields, field_parsers={'slot': nolabel2o}))
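# Fix the random seeds (and force deterministic cuDNN behaviour on GPU) so runs are reproducible.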
random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
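# Build a flair Corpus from the two splits and derive the slot tag dictionary from it.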
corpus = Corpus(train=conllu2flair(trainset, 'slot'), test=conllu2flair(testset, 'slot'))
tag_dictionary = corpus.make_tag_dictionary(tag_type='slot')
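# Stack Polish word embeddings, forward/backward flair embeddings and character embeddings.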
embedding_types = [
    WordEmbeddings('pl'),
    FlairEmbeddings('pl-forward'),
    FlairEmbeddings('pl-backward'),
    CharacterEmbeddings(),
]
embeddings = StackedEmbeddings(embeddings=embedding_types)
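# Sequence tagger with 256 hidden units and a CRF output layer, predicting the 'slot' tag.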
tagger = SequenceTagger(hidden_size=256, embeddings=embeddings,
                        tag_dictionary=tag_dictionary,
                        tag_type='slot', use_crf=True)
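# Training is left commented out; the script loads a previously trained model from disk instead.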
""" """
trainer = ModelTrainer(tagger, corpus) trainer = ModelTrainer(tagger, corpus)
trainer.train('slot-model-pl', trainer.train('slot-model-pl',
learning_rate=0.1, learning_rate=0.1,
mini_batch_size=32, mini_batch_size=32,
max_epochs=10, max_epochs=10,
train_with_dev=True) train_with_dev=True)
""" """
try:
    model = SequenceTagger.load('slot-model-pl/best-model.pt')
except:
    model = SequenceTagger.load('slot-model-pl/final-model.pt')
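# Tabulate the slot predictions for a sample Polish request
# ('One ticket in the name of Jan Kowalski for the film Batman').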
print(tabulate(predict(model, 'Jeden bilet na imię Jan Kowalski na film Batman'.split())))