Dialogue Systems
8. Semantic parsing using machine learning techniques [labs]
Marek Kubis (2021)
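Semantic parsing using machine learning techniques
Introduction
The problem of detecting slots and their values in user utterances can be formulated as a sequence labelling task. For each word, the model predicts a label saying whether the word belongs to a slot and, if so, to which one. In the following Polish utterance ("I would like to book a table for tomorrow at twelve forty-five for five people"), the words that belong to the slots day, hour and size are marked with the slot name: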
chciałbym zarezerwować stolik na jutro**/day** na godzinę dwunastą**/hour** czterdzieści**/hour** pięć**/hour** na pięć**/size** osób
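Slot boundaries are marked using a chosen tagging scheme. The scheme makes it possible to tell where one slot ends and the next begins, e.g. that dwunastą czterdzieści pięć forms a single hour value.
IOB scheme
Prefix | Meaning |
---|---|
I | inside of a slot (inside) |
O | outside any slot (outside) |
B | beginning of a slot (beginning) |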
chciałbym zarezerwować stolik na jutro**/B-day** na godzinę dwunastą**/B-hour** czterdzieści**/I-hour** pięć**/I-hour** na pięć**/B-size** osób
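The step from per-word slot names (as in the first example) to IOB tags can be sketched as follows. This is a minimal illustration, not part of the original lab. The function name `to_iob` and its input format are assumptions: a list of `(word, slot)` pairs, with `None` for words outside any slot. The sketch also assumes that consecutive words with the same slot name form one slot.

```python
from typing import List, Optional, Tuple


def to_iob(annotated: List[Tuple[str, Optional[str]]]) -> List[str]:
    """Convert per-word slot names (None = no slot) into IOB tags.

    Assumption: consecutive words with the same slot name belong to one slot,
    so two adjacent slots of the same type cannot be told apart in this input.
    """
    tags = []
    previous = None
    for _, slot in annotated:
        if slot is None:
            tags.append('O')
        elif slot == previous:
            tags.append('I-' + slot)  # the slot started at an earlier word continues
        else:
            tags.append('B-' + slot)  # a new slot starts here
        previous = slot
    return tags


utterance = [
    ('chciałbym', None), ('zarezerwować', None), ('stolik', None), ('na', None),
    ('jutro', 'day'), ('na', None), ('godzinę', None),
    ('dwunastą', 'hour'), ('czterdzieści', 'hour'), ('pięć', 'hour'),
    ('na', None), ('pięć', 'size'), ('osób', None),
]
print(to_iob(utterance))
```

The result corresponds to the IOB-annotated utterance above.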
IOBES scheme
Prefix | Meaning |
---|---|
I | inside of a slot (inside) |
O | outside any slot (outside) |
B | beginning of a slot (beginning) |
E | end of a slot (ending) |
S | single-word slot (single) |
chciałbym zarezerwować stolik na jutro**/S-day** na godzinę dwunastą**/B-hour** czterdzieści**/I-hour** pięć**/E-hour** na pięć**/S-size** osób
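The IOBES tags carry the same information as the IOB tags, so one scheme can be converted into the other. A minimal sketch follows, assuming a well-formed IOB input in which an `I-` tag never starts a slot. The function name `iob_to_iobes` is hypothetical. A `B-` tag becomes `S-` when the slot has only one word, and the last `I-` tag of a multi-word slot becomes `E-`.

```python
from typing import List


def iob_to_iobes(tags: List[str]) -> List[str]:
    """Convert a well-formed IOB tag sequence into the IOBES scheme."""
    result = []
    for i, tag in enumerate(tags):
        if tag == 'O':
            result.append(tag)
            continue
        prefix, slot = tag.split('-', 1)
        # does the same slot continue in the next word?
        next_tag = tags[i + 1] if i + 1 < len(tags) else 'O'
        continues = next_tag == 'I-' + slot
        if prefix == 'B':
            result.append(('B-' if continues else 'S-') + slot)
        else:  # 'I': inside or last word of a multi-word slot
            result.append(('I-' if continues else 'E-') + slot)
    return result


iob_tags = ['O', 'O', 'O', 'O', 'B-day', 'O', 'O',
            'B-hour', 'I-hour', 'I-hour', 'O', 'B-size', 'O']
print(iob_to_iobes(iob_tags))
```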
Suppose we prepare a dataset of user utterances with annotated slots (a so-called _training set_) for the task formulated this way. We can then apply (supervised) machine learning techniques to build a model that annotates user utterances with slot labels.
Such a model can be built with, among others:
- conditional random fields (Lafferty et al., 2001),
- recurrent neural networks, e.g. LSTM networks (Hochreiter and Schmidhuber, 1997),
- transformers (Vaswani et al., 2017).
Example
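We will use the dataset prepared by Schuster et al. (2019). It contains utterances in three languages, annotated with slots for twelve frames from three domains: Alarm, Reminder and Weather. Note, however, that the code below loads the Polish files from ../tasks/zad8/pl/; the paths to the English files are commented out. Judging by the examples, these Polish utterances concern films and ticket reservations, with frames such as inform and reqmore and slots such as title, date and seats. We load the data with the conllu library.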
from itertools import islice
from conllu import parse_incr
fields = ['id', 'form', 'frame', 'slot']
def nolabel2o(line, i):
    # map the dataset's 'NoLabel' marker to the IOB tag 'O'
    return 'O' if line[i] == 'NoLabel' else line[i]
# pathTrain = '../tasks/zad8/en/train-en.conllu'
# pathTest = '../tasks/zad8/en/test-en.conllu'
pathTrain = '../tasks/zad8/pl/train.conllu'
pathTest = '../tasks/zad8/pl/test.conllu'
# preview the first 15 lines of the training file
with open(pathTrain, encoding="UTF-8") as trainfile:
    for line in islice(trainfile, 15):
        print(line, end='')
# reopen the file so that the preview does not consume the first utterances
with open(pathTrain, encoding="UTF-8") as trainfile:
    trainset = list(parse_incr(trainfile, fields=fields, field_parsers={'slot': nolabel2o}))
with open(pathTest, encoding="UTF-8") as testfile:
    testset = list(parse_incr(testfile, fields=fields, field_parsers={'slot': nolabel2o}))
# text: halo
# intent: hello
# slots:
1 halo hello NoLabel

# text: chaciałbym pójść na premierę filmu jakie premiery są w tym tygodniu
# intent: reqmore
# slots:
1 chaciałbym reqmore NoLabel
2 pójść reqmore NoLabel
3 na reqmore NoLabel
4 premierę reqmore NoLabel
5 filmu reqmore NoLabel
6 jakie reqmore B-goal
7 premiery reqmore I-goal
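The preview shows the layout of the file. Each utterance is a group of lines separated from the next group by a blank line. Comment lines of the form `# key: value` carry the text, the intent and the slots. Each word is one line with the columns id, form, frame and slot, and `NoLabel` marks words outside any slot. The sketch below reads this layout with the standard library only. It is a hypothetical illustration of what `parse_incr` with `nolabel2o` gives us for this simple layout, not the conllu library's implementation. The function name `read_utterances` is an assumption, as is the use of tab-separated columns (the CoNLL-U convention). The sample is the 15 lines printed above, so the second utterance is cut off after seven words.

```python
from typing import Dict, List

FIELDS = ['id', 'form', 'frame', 'slot']


def read_utterances(text: str) -> List[Dict]:
    """Hypothetical stdlib-only reader for the 4-column layout shown above."""
    utterances = []
    meta: Dict[str, str] = {}
    tokens: List[Dict[str, str]] = []
    for line in text.splitlines() + ['']:  # the extra '' flushes the last utterance
        if not line.strip():
            # a blank line closes the current utterance
            if tokens:
                utterances.append({'meta': meta, 'tokens': tokens})
            meta, tokens = {}, []
        elif line.startswith('#'):
            # "# key: value" comment lines hold utterance-level metadata
            key, _, value = line[1:].partition(':')
            meta[key.strip()] = value.strip()
        else:
            # assumption: columns are tab-separated, as in CoNLL-U
            token = dict(zip(FIELDS, line.split('\t')))
            if token['slot'] == 'NoLabel':
                token['slot'] = 'O'  # same mapping as nolabel2o
            tokens.append(token)
    return utterances


sample = (
    "# text: halo\n"
    "# intent: hello\n"
    "# slots:\n"
    "1\thalo\thello\tNoLabel\n"
    "\n"
    "# text: chaciałbym pójść na premierę filmu jakie premiery są w tym tygodniu\n"
    "# intent: reqmore\n"
    "# slots:\n"
    "1\tchaciałbym\treqmore\tNoLabel\n"
    "2\tpójść\treqmore\tNoLabel\n"
    "3\tna\treqmore\tNoLabel\n"
    "4\tpremierę\treqmore\tNoLabel\n"
    "5\tfilmu\treqmore\tNoLabel\n"
    "6\tjakie\treqmore\tB-goal\n"
    "7\tpremiery\treqmore\tI-goal\n"
)
utterances = read_utterances(sample)
for u in utterances:
    print(u['meta']['intent'], [(t['form'], t['slot']) for t in u['tokens']])
```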
Let us look at a few example utterances from this dataset.
from tabulate import tabulate
tabulate(trainset[1], tablefmt='html')
1 | wybieram | inform | O |
2 | batmana | inform | B-title |
tabulate(trainset[16], tablefmt='html')
1 | chcę | inform | O |
2 | zarezerwować | inform | B-goal |
3 | bilety | inform | O |
tabulate(trainset[20], tablefmt='html')
1 | chciałbym | inform | O |
2 | anulować | inform | O |
3 | rezerwację | inform | O |
4 | biletu | inform | O |
To build the model, we will use an architecture based on recurrent neural networks implemented in the flair library (Akbik et al., 2018).
from flair.data import Corpus, Sentence, Token
from flair.datasets import SentenceDataset
from flair.embeddings import StackedEmbeddings
from flair.embeddings import WordEmbeddings
from flair.embeddings import CharacterEmbeddings
from flair.embeddings import FlairEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer
# make computations deterministic
import random
import torch
random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
We convert the data to the format used by flair with the following function.
def conllu2flair(sentences, label=None):
    # convert conllu sentences into a flair SentenceDataset
    fsentences = []
    for sentence in sentences:
        fsentence = Sentence()
        for token in sentence:
            ftoken = Token(token['form'])
            if label:
                # copy the gold annotation (e.g. the 'slot' column) as a tag of type `label`
                ftoken.add_tag(label, token[label])
            fsentence.add_token(ftoken)
        fsentences.append(fsentence)
    return SentenceDataset(fsentences)
corpus = Corpus(train=conllu2flair(trainset, 'slot'), test=conllu2flair(testset, 'slot'))
print(corpus)
tag_dictionary = corpus.make_tag_dictionary(tag_type='slot')
print(tag_dictionary)
Corpus: 345 train + 38 dev + 32 test sentences
Dictionary with 20 tags: <unk>, O, B-interval, I-interval, B-title, B-date, I-date, B-time, B-quantity, B-area, I-area, B-goal, I-goal, I-title, I-time, I-quantity, B-seats, I-seats, <START>, <STOP>
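No dev set was passed to Corpus, so flair set aside part of the training data (38 sentences) as a dev set. Besides the tags occurring in the data, the tag dictionary contains the special tags <unk>, <START> and <STOP>. Our model will use vector representations of words (see Word Embeddings). It stacks four of them into a single vector per word: static Polish word embeddings (WordEmbeddings('pl')), forward and backward contextual Flair embeddings, and character-level embeddings.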
embedding_types = [
WordEmbeddings('pl'),
FlairEmbeddings('polish-forward'),
FlairEmbeddings('polish-backward'),
CharacterEmbeddings(),
]
embeddings = StackedEmbeddings(embeddings=embedding_types)
tagger = SequenceTagger(hidden_size=256, embeddings=embeddings,
tag_dictionary=tag_dictionary,
tag_type='slot', use_crf=True)
Let us look at the architecture of the neural network that will predict slots in utterances.
print(tagger)
SequenceTagger(
  (embeddings): StackedEmbeddings(
    (list_embedding_0): WordEmbeddings('pl')
    (list_embedding_1): FlairEmbeddings(
      (lm): LanguageModel(
        (drop): Dropout(p=0.25, inplace=False)
        (encoder): Embedding(1602, 100)
        (rnn): LSTM(100, 2048)
        (decoder): Linear(in_features=2048, out_features=1602, bias=True)
      )
    )
    (list_embedding_2): FlairEmbeddings(
      (lm): LanguageModel(
        (drop): Dropout(p=0.25, inplace=False)
        (encoder): Embedding(1602, 100)
        (rnn): LSTM(100, 2048)
        (decoder): Linear(in_features=2048, out_features=1602, bias=True)
      )
    )
    (list_embedding_3): CharacterEmbeddings(
      (char_embedding): Embedding(275, 25)
      (char_rnn): LSTM(25, 25, bidirectional=True)
    )
  )
  (word_dropout): WordDropout(p=0.05)
  (locked_dropout): LockedDropout(p=0.5)
  (embedding2nn): Linear(in_features=4446, out_features=4446, bias=True)
  (rnn): LSTM(4446, 256, batch_first=True, bidirectional=True)
  (linear): Linear(in_features=512, out_features=20, bias=True)
  (beta): 1.0
  (weights): None
  (weight_tensor) None
)
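The printout shows how the layer sizes fit together. The stacked embedding vector (4446 dimensions, the input of embedding2nn) is the concatenation of the individual embeddings. The BiLSTM concatenates its forward and backward states (2 × 256 = 512), and the final linear layer maps them to the 20 tags of the dictionary. The sketch below only does this arithmetic on numbers taken from the printout. The size of WordEmbeddings('pl') is not printed, so it is derived as the remainder. We assume this corresponds to 300-dimensional fastText vectors, but that is not visible in the output.

```python
# Sizes read off the printed architecture above.
stacked_dim = 4446          # (embedding2nn): Linear(in_features=4446, ...)
flair_forward_dim = 2048    # list_embedding_1: (rnn): LSTM(100, 2048)
flair_backward_dim = 2048   # list_embedding_2: (rnn): LSTM(100, 2048)
char_dim = 2 * 25           # list_embedding_3: LSTM(25, 25, bidirectional=True)

# whatever remains must come from WordEmbeddings('pl')
word_dim = stacked_dim - flair_forward_dim - flair_backward_dim - char_dim
print(word_dim)             # 300

hidden_size = 256                 # SequenceTagger(hidden_size=256, ...)
bilstm_out = 2 * hidden_size      # forward and backward states are concatenated
num_tags = 20                     # size of tag_dictionary
print(bilstm_out, num_tags)       # 512 20, cf. (linear): Linear(in_features=512, out_features=20)
```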
We train the model for ten epochs (max_epochs=10) with learning rate 0.1 and mini-batches of 32 sentences, and save the resulting model in the slot-model directory.
trainer = ModelTrainer(tagger, corpus)
trainer.train('slot-model',
learning_rate=0.1,
mini_batch_size=32,
max_epochs=10,
train_with_dev=False)
2022-05-01 12:13:39,609 ----------------------------------------------------------------------------------------------------
2022-05-01 12:13:39,610 Model: "SequenceTagger(
  (embeddings): StackedEmbeddings(
    (list_embedding_0): WordEmbeddings('pl')
    (list_embedding_1): FlairEmbeddings(
      (lm): LanguageModel(
        (drop): Dropout(p=0.25, inplace=False)
        (encoder): Embedding(1602, 100)
        (rnn): LSTM(100, 2048)
        (decoder): Linear(in_features=2048, out_features=1602, bias=True)
      )
    )
    (list_embedding_2): FlairEmbeddings(
      (lm): LanguageModel(
        (drop): Dropout(p=0.25, inplace=False)
        (encoder): Embedding(1602, 100)
        (rnn): LSTM(100, 2048)
        (decoder): Linear(in_features=2048, out_features=1602, bias=True)
      )
    )
    (list_embedding_3): CharacterEmbeddings(
      (char_embedding): Embedding(275, 25)
      (char_rnn): LSTM(25, 25, bidirectional=True)
    )
  )
  (word_dropout): WordDropout(p=0.05)
  (locked_dropout): LockedDropout(p=0.5)
  (embedding2nn): Linear(in_features=4446, out_features=4446, bias=True)
  (rnn): LSTM(4446, 256, batch_first=True, bidirectional=True)
  (linear): Linear(in_features=512, out_features=20, bias=True)
  (beta): 1.0
  (weights): None
  (weight_tensor) None
)"
2022-05-01 12:13:39,611 ----------------------------------------------------------------------------------------------------
2022-05-01 12:13:39,611 Corpus: "Corpus: 345 train + 38 dev + 32 test sentences"
2022-05-01 12:13:39,612 ----------------------------------------------------------------------------------------------------
2022-05-01 12:13:39,613 Parameters:
2022-05-01 12:13:39,614  - learning_rate: "0.1"
2022-05-01 12:13:39,614  - mini_batch_size: "32"
2022-05-01 12:13:39,615  - patience: "3"
2022-05-01 12:13:39,616  - anneal_factor: "0.5"
2022-05-01 12:13:39,616  - max_epochs: "10"
2022-05-01 12:13:39,616  - shuffle: "True"
2022-05-01 12:13:39,617  - train_with_dev: "False"
2022-05-01 12:13:39,618  - batch_growth_annealing: "False"
2022-05-01 12:13:39,618 ----------------------------------------------------------------------------------------------------
2022-05-01 12:13:39,619 Model training base path: "slot-model"
2022-05-01 12:13:39,620 ----------------------------------------------------------------------------------------------------
2022-05-01 12:13:39,620 Device: cpu
2022-05-01 12:13:39,621 ----------------------------------------------------------------------------------------------------
2022-05-01 12:13:39,621 Embeddings storage mode: cpu
2022-05-01 12:13:39,623 ----------------------------------------------------------------------------------------------------
2022-05-01 12:13:42,490 epoch 1 - iter 1/11 - loss 9.59000492 - samples/sec: 11.17 - lr: 0.100000
2022-05-01 12:13:44,150 epoch 1 - iter 2/11 - loss 9.31767702 - samples/sec: 19.29 - lr: 0.100000
2022-05-01 12:13:45,968 epoch 1 - iter 3/11 - loss 8.70617644 - samples/sec: 17.61 - lr: 0.100000
2022-05-01 12:13:47,791 epoch 1 - iter 4/11 - loss 8.11678410 - samples/sec: 17.57 - lr: 0.100000
2022-05-01 12:13:49,815 epoch 1 - iter 5/11 - loss 7.65581417 - samples/sec: 15.82 - lr: 0.100000
2022-05-01 12:13:52,296 epoch 1 - iter 6/11 - loss 7.27475810 - samples/sec: 12.90 - lr: 0.100000
2022-05-01 12:13:54,454 epoch 1 - iter 7/11 - loss 6.95693064 - samples/sec: 14.84 - lr: 0.100000
2022-05-01 12:13:56,845 epoch 1 - iter 8/11 - loss 6.61199290 - samples/sec: 13.39 - lr: 0.100000
2022-05-01 12:13:59,195 epoch 1 - iter 9/11 - loss 6.58955601 - samples/sec: 13.63 - lr: 0.100000
2022-05-01 12:14:01,065 epoch 1 - iter 10/11 - loss 6.63135071 - samples/sec: 17.11 - lr: 0.100000
2022-05-01 12:14:02,415 epoch 1 - iter 11/11 - loss 6.52558366 - samples/sec: 23.72 - lr: 0.100000
2022-05-01 12:14:02,416 ----------------------------------------------------------------------------------------------------
2022-05-01 12:14:02,417 EPOCH 1 done: loss 6.5256 - lr 0.1000000
2022-05-01 12:14:05,139 DEV : loss 8.419286727905273 - score 0.0
2022-05-01 12:14:05,141 BAD EPOCHS (no improvement): 0
saving best model
2022-05-01 12:14:15,906 ----------------------------------------------------------------------------------------------------
2022-05-01 12:14:16,782 epoch 2 - iter 1/11 - loss 7.61237478 - samples/sec: 40.25 - lr: 0.100000
2022-05-01 12:14:17,253 epoch 2 - iter 2/11 - loss 7.02023911 - samples/sec: 68.09 - lr: 0.100000
2022-05-01 12:14:17,744 epoch 2 - iter 3/11 - loss 6.25125138 - samples/sec: 65.31 - lr: 0.100000
2022-05-01 12:14:18,282 epoch 2 - iter 4/11 - loss 5.91574061 - samples/sec: 59.59 - lr: 0.100000
2022-05-01 12:14:18,742 epoch 2 - iter 5/11 - loss 5.80905600 - samples/sec: 69.87 - lr: 0.100000
2022-05-01 12:14:19,262 epoch 2 - iter 6/11 - loss 5.51969266 - samples/sec: 61.66 - lr: 0.100000
2022-05-01 12:14:19,753 epoch 2 - iter 7/11 - loss 5.34836953 - samples/sec: 65.31 - lr: 0.100000
2022-05-01 12:14:20,267 epoch 2 - iter 8/11 - loss 5.33710295 - samples/sec: 62.38 - lr: 0.100000
2022-05-01 12:14:20,750 epoch 2 - iter 9/11 - loss 5.28061861 - samples/sec: 66.32 - lr: 0.100000
2022-05-01 12:14:21,379 epoch 2 - iter 10/11 - loss 5.20552692 - samples/sec: 50.95 - lr: 0.100000
2022-05-01 12:14:21,922 epoch 2 - iter 11/11 - loss 5.26294283 - samples/sec: 59.03 - lr: 0.100000
2022-05-01 12:14:21,923 ----------------------------------------------------------------------------------------------------
2022-05-01 12:14:21,924 EPOCH 2 done: loss 5.2629 - lr 0.1000000
2022-05-01 12:14:22,145 DEV : loss 7.168168544769287 - score 0.0645
2022-05-01 12:14:22,149 BAD EPOCHS (no improvement): 0
saving best model
2022-05-01 12:14:27,939 ----------------------------------------------------------------------------------------------------
2022-05-01 12:14:28,495 epoch 3 - iter 1/11 - loss 3.70659065 - samples/sec: 57.56 - lr: 0.100000
2022-05-01 12:14:29,038 epoch 3 - iter 2/11 - loss 4.21530080 - samples/sec: 59.04 - lr: 0.100000
2022-05-01 12:14:29,607 epoch 3 - iter 3/11 - loss 4.40864404 - samples/sec: 56.37 - lr: 0.100000
2022-05-01 12:14:30,171 epoch 3 - iter 4/11 - loss 4.69527233 - samples/sec: 56.93 - lr: 0.100000
2022-05-01 12:14:30,587 epoch 3 - iter 5/11 - loss 4.43719640 - samples/sec: 77.11 - lr: 0.100000
2022-05-01 12:14:31,075 epoch 3 - iter 6/11 - loss 4.55344125 - samples/sec: 65.71 - lr: 0.100000
2022-05-01 12:14:31,625 epoch 3 - iter 7/11 - loss 4.77397609 - samples/sec: 58.34 - lr: 0.100000
2022-05-01 12:14:32,143 epoch 3 - iter 8/11 - loss 4.61572361 - samples/sec: 61.89 - lr: 0.100000
2022-05-01 12:14:32,703 epoch 3 - iter 9/11 - loss 4.60090372 - samples/sec: 57.24 - lr: 0.100000
2022-05-01 12:14:33,404 epoch 3 - iter 10/11 - loss 4.70502276 - samples/sec: 45.69 - lr: 0.100000
2022-05-01 12:14:33,839 epoch 3 - iter 11/11 - loss 4.76321775 - samples/sec: 73.73 - lr: 0.100000
2022-05-01 12:14:33,840 ----------------------------------------------------------------------------------------------------
2022-05-01 12:14:33,840 EPOCH 3 done: loss 4.7632 - lr 0.1000000
2022-05-01 12:14:33,992 DEV : loss 7.209894180297852 - score 0.0
2022-05-01 12:14:33,993 BAD EPOCHS (no improvement): 1
2022-05-01 12:14:33,994 ----------------------------------------------------------------------------------------------------
2022-05-01 12:14:34,556 epoch 4 - iter 1/11 - loss 5.55247641 - samples/sec: 57.04 - lr: 0.100000
2022-05-01 12:14:35,078 epoch 4 - iter 2/11 - loss 5.08158088 - samples/sec: 61.42 - lr: 0.100000
2022-05-01 12:14:35,643 epoch 4 - iter 3/11 - loss 4.69475476 - samples/sec: 56.73 - lr: 0.100000
2022-05-01 12:14:36,270 epoch 4 - iter 4/11 - loss 4.78649628 - samples/sec: 51.16 - lr: 0.100000
2022-05-01 12:14:36,806 epoch 4 - iter 5/11 - loss 4.62873497 - samples/sec: 59.93 - lr: 0.100000
2022-05-01 12:14:37,419 epoch 4 - iter 6/11 - loss 4.70938087 - samples/sec: 52.29 - lr: 0.100000
2022-05-01 12:14:38,068 epoch 4 - iter 7/11 - loss 4.50588363 - samples/sec: 49.46 - lr: 0.100000
2022-05-01 12:14:38,581 epoch 4 - iter 8/11 - loss 4.36334288 - samples/sec: 62.50 - lr: 0.100000
2022-05-01 12:14:39,140 epoch 4 - iter 9/11 - loss 4.36617618 - samples/sec: 57.45 - lr: 0.100000
2022-05-01 12:14:39,780 epoch 4 - iter 10/11 - loss 4.37847199 - samples/sec: 50.16 - lr: 0.100000
2022-05-01 12:14:40,321 epoch 4 - iter 11/11 - loss 4.26116128 - samples/sec: 59.18 - lr: 0.100000
2022-05-01 12:14:40,323 ----------------------------------------------------------------------------------------------------
2022-05-01 12:14:40,324 EPOCH 4 done: loss 4.2612 - lr 0.1000000
2022-05-01 12:14:40,544 DEV : loss 5.882441997528076 - score 0.1714
2022-05-01 12:14:40,546 BAD EPOCHS (no improvement): 0
saving best model
2022-05-01 12:14:46,159 ----------------------------------------------------------------------------------------------------
2022-05-01 12:14:46,709 epoch 5 - iter 1/11 - loss 3.86370564 - samples/sec: 58.29 - lr: 0.100000
2022-05-01 12:14:47,349 epoch 5 - iter 2/11 - loss 3.80554891 - samples/sec: 50.08 - lr: 0.100000
2022-05-01 12:14:47,857 epoch 5 - iter 3/11 - loss 3.34506067 - samples/sec: 63.11 - lr: 0.100000
2022-05-01 12:14:48,579 epoch 5 - iter 4/11 - loss 3.88535106 - samples/sec: 44.38 - lr: 0.100000
2022-05-01 12:14:49,170 epoch 5 - iter 5/11 - loss 3.81894360 - samples/sec: 54.28 - lr: 0.100000
2022-05-01 12:14:49,708 epoch 5 - iter 6/11 - loss 4.18858314 - samples/sec: 59.53 - lr: 0.100000
2022-05-01 12:14:50,171 epoch 5 - iter 7/11 - loss 4.13974752 - samples/sec: 69.26 - lr: 0.100000
2022-05-01 12:14:50,593 epoch 5 - iter 8/11 - loss 4.01002905 - samples/sec: 75.98 - lr: 0.100000
2022-05-01 12:14:51,062 epoch 5 - iter 9/11 - loss 3.97078644 - samples/sec: 68.52 - lr: 0.100000
2022-05-01 12:14:51,508 epoch 5 - iter 10/11 - loss 3.94409857 - samples/sec: 71.91 - lr: 0.100000
2022-05-01 12:14:51,960 epoch 5 - iter 11/11 - loss 3.80738796 - samples/sec: 70.95 - lr: 0.100000
2022-05-01 12:14:51,961 ----------------------------------------------------------------------------------------------------
2022-05-01 12:14:51,963 EPOCH 5 done: loss 3.8074 - lr 0.1000000
2022-05-01 12:14:52,103 DEV : loss 5.224854469299316 - score 0.1667
2022-05-01 12:14:52,105 BAD EPOCHS (no improvement): 1
2022-05-01 12:14:52,106 ----------------------------------------------------------------------------------------------------
2022-05-01 12:14:52,616 epoch 6 - iter 1/11 - loss 3.51282573 - samples/sec: 62.91 - lr: 0.100000
2022-05-01 12:14:53,100 epoch 6 - iter 2/11 - loss 3.41601551 - samples/sec: 66.25 - lr: 0.100000
2022-05-01 12:14:53,513 epoch 6 - iter 3/11 - loss 3.08380787 - samples/sec: 77.76 - lr: 0.100000
2022-05-01 12:14:55,121 epoch 6 - iter 4/11 - loss 3.21056002 - samples/sec: 64.71 - lr: 0.100000
2022-05-01 12:14:55,665 epoch 6 - iter 5/11 - loss 3.30184879 - samples/sec: 58.88 - lr: 0.100000
2022-05-01 12:14:56,160 epoch 6 - iter 6/11 - loss 3.20993070 - samples/sec: 64.91 - lr: 0.100000
2022-05-01 12:14:56,670 epoch 6 - iter 7/11 - loss 3.14396119 - samples/sec: 62.91 - lr: 0.100000
2022-05-01 12:14:57,329 epoch 6 - iter 8/11 - loss 3.24591878 - samples/sec: 48.63 - lr: 0.100000
2022-05-01 12:14:57,958 epoch 6 - iter 9/11 - loss 3.31877112 - samples/sec: 51.03 - lr: 0.100000
2022-05-01 12:14:58,527 epoch 6 - iter 10/11 - loss 3.33475649 - samples/sec: 56.34 - lr: 0.100000
2022-05-01 12:14:58,989 epoch 6 - iter 11/11 - loss 3.23232636 - samples/sec: 69.41 - lr: 0.100000
2022-05-01 12:14:58,991 ----------------------------------------------------------------------------------------------------
2022-05-01 12:14:58,991 EPOCH 6 done: loss 3.2323 - lr 0.1000000
2022-05-01 12:14:59,178 DEV : loss 4.557621002197266 - score 0.2381
2022-05-01 12:14:59,180 BAD EPOCHS (no improvement): 0
saving best model
2022-05-01 12:15:25,844 ----------------------------------------------------------------------------------------------------
2022-05-01 12:15:26,423 epoch 7 - iter 1/11 - loss 2.71161938 - samples/sec: 55.36 - lr: 0.100000
2022-05-01 12:15:26,886 epoch 7 - iter 2/11 - loss 2.50157821 - samples/sec: 69.26 - lr: 0.100000
2022-05-01 12:15:27,347 epoch 7 - iter 3/11 - loss 2.78014056 - samples/sec: 69.56 - lr: 0.100000
2022-05-01 12:15:27,853 epoch 7 - iter 4/11 - loss 2.82983196 - samples/sec: 63.36 - lr: 0.100000
2022-05-01 12:15:28,393 epoch 7 - iter 5/11 - loss 2.84246483 - samples/sec: 59.37 - lr: 0.100000
2022-05-01 12:15:28,847 epoch 7 - iter 6/11 - loss 2.89787177 - samples/sec: 70.64 - lr: 0.100000
2022-05-01 12:15:29,338 epoch 7 - iter 7/11 - loss 2.74564961 - samples/sec: 65.30 - lr: 0.100000
2022-05-01 12:15:29,813 epoch 7 - iter 8/11 - loss 2.79853699 - samples/sec: 67.58 - lr: 0.100000
2022-05-01 12:15:30,364 epoch 7 - iter 9/11 - loss 2.89167126 - samples/sec: 58.18 - lr: 0.100000
2022-05-01 12:15:30,834 epoch 7 - iter 10/11 - loss 2.86527851 - samples/sec: 68.22 - lr: 0.100000
2022-05-01 12:15:31,296 epoch 7 - iter 11/11 - loss 2.82858575 - samples/sec: 69.41 - lr: 0.100000
2022-05-01 12:15:31,297 ----------------------------------------------------------------------------------------------------
2022-05-01 12:15:31,298 EPOCH 7 done: loss 2.8286 - lr 0.1000000
2022-05-01 12:15:31,462 DEV : loss 4.020608901977539 - score 0.3182
2022-05-01 12:15:31,463 BAD EPOCHS (no improvement): 0
saving best model
2022-05-01 12:15:38,431 ----------------------------------------------------------------------------------------------------
2022-05-01 12:15:38,979 epoch 8 - iter 1/11 - loss 3.28806710 - samples/sec: 58.61 - lr: 0.100000
2022-05-01 12:15:39,534 epoch 8 - iter 2/11 - loss 2.72140074 - samples/sec: 57.76 - lr: 0.100000
2022-05-01 12:15:40,061 epoch 8 - iter 3/11 - loss 2.77740423 - samples/sec: 60.89 - lr: 0.100000
2022-05-01 12:15:40,541 epoch 8 - iter 4/11 - loss 2.51573136 - samples/sec: 66.72 - lr: 0.100000
2022-05-01 12:15:41,109 epoch 8 - iter 5/11 - loss 2.54271443 - samples/sec: 56.53 - lr: 0.100000
2022-05-01 12:15:41,537 epoch 8 - iter 6/11 - loss 2.47530021 - samples/sec: 75.12 - lr: 0.100000
2022-05-01 12:15:42,078 epoch 8 - iter 7/11 - loss 2.62978831 - samples/sec: 59.26 - lr: 0.100000
2022-05-01 12:15:42,506 epoch 8 - iter 8/11 - loss 2.62844713 - samples/sec: 74.84 - lr: 0.100000
2022-05-01 12:15:42,988 epoch 8 - iter 9/11 - loss 2.61604464 - samples/sec: 66.59 - lr: 0.100000
2022-05-01 12:15:43,471 epoch 8 - iter 10/11 - loss 2.62512223 - samples/sec: 66.39 - lr: 0.100000
2022-05-01 12:15:43,895 epoch 8 - iter 11/11 - loss 2.64045010 - samples/sec: 75.65 - lr: 0.100000
2022-05-01 12:15:43,896 ----------------------------------------------------------------------------------------------------
2022-05-01 12:15:43,897 EPOCH 8 done: loss 2.6405 - lr 0.1000000
2022-05-01 12:15:44,036 DEV : loss 3.542769432067871 - score 0.3846
2022-05-01 12:15:44,038 BAD EPOCHS (no improvement): 0
saving best model
2022-05-01 12:15:51,672 ----------------------------------------------------------------------------------------------------
2022-05-01 12:15:52,235 epoch 9 - iter 1/11 - loss 1.73337626 - samples/sec: 56.99 - lr: 0.100000
2022-05-01 12:15:52,801 epoch 9 - iter 2/11 - loss 2.09788013 - samples/sec: 56.74 - lr: 0.100000
2022-05-01 12:15:53,288 epoch 9 - iter 3/11 - loss 2.24861153 - samples/sec: 65.84 - lr: 0.100000
2022-05-01 12:15:53,735 epoch 9 - iter 4/11 - loss 2.42630130 - samples/sec: 71.75 - lr: 0.100000
2022-05-01 12:15:54,189 epoch 9 - iter 5/11 - loss 2.42454610 - samples/sec: 70.64 - lr: 0.100000
2022-05-01 12:15:54,720 epoch 9 - iter 6/11 - loss 2.39987107 - samples/sec: 60.38 - lr: 0.100000
2022-05-01 12:15:55,192 epoch 9 - iter 7/11 - loss 2.29154910 - samples/sec: 67.94 - lr: 0.100000
2022-05-01 12:15:55,632 epoch 9 - iter 8/11 - loss 2.22984707 - samples/sec: 73.06 - lr: 0.100000
2022-05-01 12:15:56,162 epoch 9 - iter 9/11 - loss 2.32317919 - samples/sec: 60.49 - lr: 0.100000
2022-05-01 12:15:56,559 epoch 9 - iter 10/11 - loss 2.24865967 - samples/sec: 80.81 - lr: 0.100000
2022-05-01 12:15:56,986 epoch 9 - iter 11/11 - loss 2.27327953 - samples/sec: 75.12 - lr: 0.100000
2022-05-01 12:15:56,988 ----------------------------------------------------------------------------------------------------
2022-05-01 12:15:56,988 EPOCH 9 done: loss 2.2733 - lr 0.1000000
2022-05-01 12:15:57,130 DEV : loss 3.4634602069854736 - score 0.5517
2022-05-01 12:15:57,132 BAD EPOCHS (no improvement): 0
saving best model
2022-05-01 12:16:04,067 ----------------------------------------------------------------------------------------------------
2022-05-01 12:16:04,643 epoch 10 - iter 1/11 - loss 2.22972107 - samples/sec: 55.65 - lr: 0.100000
2022-05-01 12:16:05,144 epoch 10 - iter 2/11 - loss 2.20346498 - samples/sec: 64.00 - lr: 0.100000
2022-05-01 12:16:05,576 epoch 10 - iter 3/11 - loss 2.07501336 - samples/sec: 74.24 - lr: 0.100000
2022-05-01 12:16:06,036 epoch 10 - iter 4/11 - loss 2.09982607 - samples/sec: 69.72 - lr: 0.100000
2022-05-01 12:16:06,508 epoch 10 - iter 5/11 - loss 2.08048103 - samples/sec: 67.94 - lr: 0.100000
2022-05-01 12:16:07,062 epoch 10 - iter 6/11 - loss 2.08074635 - samples/sec: 57.87 - lr: 0.100000
2022-05-01 12:16:07,590 epoch 10 - iter 7/11 - loss 2.07187140 - samples/sec: 60.84 - lr: 0.100000
2022-05-01 12:16:08,116 epoch 10 - iter 8/11 - loss 2.10148455 - samples/sec: 60.95 - lr: 0.100000
2022-05-01 12:16:08,563 epoch 10 - iter 9/11 - loss 2.06198527 - samples/sec: 71.74 - lr: 0.100000
2022-05-01 12:16:09,066 epoch 10 - iter 10/11 - loss 2.00194792 - samples/sec: 63.75 - lr: 0.100000
2022-05-01 12:16:09,486 epoch 10 - iter 11/11 - loss 2.00801701 - samples/sec: 76.37 - lr: 0.100000
2022-05-01 12:16:09,487 ----------------------------------------------------------------------------------------------------
2022-05-01 12:16:09,488 EPOCH 10 done: loss 2.0080 - lr 0.1000000
2022-05-01 12:16:09,624 DEV : loss 3.1866908073425293 - score 0.4706
2022-05-01 12:16:09,625 BAD EPOCHS (no improvement): 1
2022-05-01 12:16:16,655 ----------------------------------------------------------------------------------------------------
2022-05-01 12:16:16,656 Testing using best model ...
2022-05-01 12:16:16,676 loading file slot-model\best-model.pt
2022-05-01 12:16:22,739 0.4231 0.3056 0.3548
2022-05-01 12:16:22,740 Results:
- F1-score (micro) 0.3548
- F1-score (macro) 0.2570
By class:
area tp: 1 - fp: 1 - fn: 2 - precision: 0.5000 - recall: 0.3333 - f1-score: 0.4000
date tp: 0 - fp: 3 - fn: 3 - precision: 0.0000 - recall: 0.0000 - f1-score: 0.0000
goal tp: 2 - fp: 2 - fn: 8 - precision: 0.5000 - recall: 0.2000 - f1-score: 0.2857
interval tp: 0 - fp: 0 - fn: 1 - precision: 0.0000 - recall: 0.0000 - f1-score: 0.0000
quantity tp: 4 - fp: 1 - fn: 2 - precision: 0.8000 - recall: 0.6667 - f1-score: 0.7273
seats tp: 0 - fp: 1 - fn: 0 - precision: 0.0000 - recall: 0.0000 - f1-score: 0.0000
time tp: 1 - fp: 4 - fn: 5 - precision: 0.2000 - recall: 0.1667 - f1-score: 0.1818
title tp: 3 - fp: 3 - fn: 4 - precision: 0.5000 - recall: 0.4286 - f1-score: 0.4615
2022-05-01 12:16:22,740 ----------------------------------------------------------------------------------------------------
{'test_score': 0.3548387096774194, 'dev_score_history': [0.0, 0.06451612903225806, 0.0, 0.17142857142857143, 0.16666666666666663, 0.23809523809523808, 0.3181818181818182, 0.38461538461538464, 0.5517241379310345, 0.47058823529411764], 'train_loss_history': [6.525583657351407, 5.26294283433394, 4.7632177526300605, 4.261161284013228, 3.807387958873402, 3.2323263558474453, 2.828585754741322, 2.6404500982978125, 2.2732795260169287, 2.0080170089548286], 'dev_loss_history': [8.419286727905273, 7.168168544769287, 7.209894180297852, 5.882441997528076, 5.224854469299316, 4.557621002197266, 4.020608901977539, 3.542769432067871, 3.4634602069854736, 3.1866908073425293]}
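Two details of the log can be checked with simple arithmetic. First, 345 training sentences split into mini-batches of 32 give ⌈345/32⌉ = 11 mini-batches per epoch, hence `iter 11/11`. Second, `dev_score_history` peaks in epoch 9, the last epoch followed by `saving best model`. That epoch-9 model is what the test run loads as best-model.pt. The numbers below are copied from the log and the returned dictionary.

```python
import math

train_sentences = 345   # "Corpus: 345 train + 38 dev + 32 test sentences"
mini_batch_size = 32    # trainer.train(..., mini_batch_size=32)
iterations_per_epoch = math.ceil(train_sentences / mini_batch_size)
print(iterations_per_epoch)  # 11

# dev_score_history copied from the dictionary returned by trainer.train
dev_score_history = [0.0, 0.06451612903225806, 0.0, 0.17142857142857143,
                     0.16666666666666663, 0.23809523809523808, 0.3181818181818182,
                     0.38461538461538464, 0.5517241379310345, 0.47058823529411764]
# epochs are numbered from 1 in the log, list indices from 0
best_epoch = 1 + max(range(len(dev_score_history)), key=lambda i: dev_score_history[i])
print(best_epoch, dev_score_history[best_epoch - 1])  # 9 0.5517241379310345
```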
We can assess the quality of the trained model with the metrics reported above. For a given label $e$ (here a slot type such as title or date) they are defined as follows:
_tp_ (true positives)
the number of slots labelled $e$ in the test set that the model also labelled $e$
_fp_ (false positives)
the number of slots labelled $e$ by the model that are not labelled $e$ in the test set
_fn_ (false negatives)
the number of slots labelled $e$ in the test set that the model did not label $e$
_precision_
$$\frac{tp}{tp + fp}$$
_recall_
$$\frac{tp}{tp + fn}$$
$F_1$
$$\frac{2 \cdot precision \cdot recall}{precision + recall}$$
(taken as 0 when $precision + recall = 0$)
_micro_ $F_1$
$F_1$ in which $tp$, $fp$ and $fn$ are summed over all labels, i.e. $tp = \sum_{e}{tp_e}$, $fp = \sum_{e}{fp_e}$, $fn = \sum_{e}{fn_e}$
_macro_ $F_1$
the arithmetic mean of the $F_1$ scores computed separately for each label.
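The micro and macro scores reported in the log can be reproduced from its per-class counts. The sketch below uses the counts copied from the "By class" section and the definitions above. The helper name `f1` is ours. It uses the identity $F_1 = \frac{2\,tp}{2\,tp + fp + fn}$, which equals $\frac{2 \cdot precision \cdot recall}{precision + recall}$ whenever $tp > 0$ and gives 0 when $tp = 0$, matching the 0.0000 entries in the log.

```python
# (tp, fp, fn) per slot type, copied from the "By class" section of the log above
counts = {
    'area': (1, 1, 2), 'date': (0, 3, 3), 'goal': (2, 2, 8), 'interval': (0, 0, 1),
    'quantity': (4, 1, 2), 'seats': (0, 1, 0), 'time': (1, 4, 5), 'title': (3, 3, 4),
}


def f1(tp: int, fp: int, fn: int) -> float:
    # 2tp / (2tp + fp + fn) == 2*P*R / (P + R) for tp > 0; it is 0 whenever tp == 0
    denominator = 2 * tp + fp + fn
    return 2 * tp / denominator if denominator else 0.0


# micro: pool the counts over all labels first
tp = sum(c[0] for c in counts.values())
fp = sum(c[1] for c in counts.values())
fn = sum(c[2] for c in counts.values())
micro_precision = tp / (tp + fp)
micro_recall = tp / (tp + fn)
micro_f1 = f1(tp, fp, fn)

# macro: average the per-label F1 scores, each label weighted equally
macro_f1 = sum(f1(*c) for c in counts.values()) / len(counts)

print(f"{micro_precision:.4f} {micro_recall:.4f} {micro_f1:.4f}")  # 0.4231 0.3056 0.3548
print(f"{macro_f1:.4f}")  # 0.2570
```

The macro score is lower than the micro score because rare labels such as date, interval and seats score 0 and count as much as frequent ones.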
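A trained model can be loaded from a file with the load method. Here we load final-model.pt, i.e. the model after the last (tenth) epoch. The test scores above, by contrast, were computed with best-model.pt, the model with the highest dev score.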
model = SequenceTagger.load('slot-model/final-model.pt')
2022-05-01 12:16:22,953 loading file slot-model/final-model.pt
The loaded model can be used to predict slots in user utterances with the predict function below.
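def predict(model, sentence):
    # wrap the words as conllu-like tokens (only 'form' is needed) and convert them to flair
    csentence = [{'form': word} for word in sentence]
    fsentence = conllu2flair([csentence])[0]
    # the tagger attaches its predictions as 'slot' tags to the tokens of fsentence
    model.predict(fsentence)
    return [(token, ftoken.get_tag('slot').value) for token, ftoken in zip(sentence, fsentence)]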
predict(model, 'kiedy gracie film zorro'.split())
[('kiedy', 'O'), ('gracie', 'O'), ('film', 'O'), ('zorro', 'B-title')]
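The tagger returns one label per word, while a dialogue system usually needs slot values, e.g. `title = "zorro"`. A minimal sketch of grouping IOB-tagged words into slot values follows. The function name `iob_to_slots` and the output format are assumptions, not part of flair. The output is a dictionary mapping slot names to lists of values, since a slot may occur more than once. As a design choice, an `I-` tag that does not continue a slot of the same type starts a new slot instead of being dropped.

```python
from typing import Dict, List, Tuple


def iob_to_slots(tagged: List[Tuple[str, str]]) -> Dict[str, List[str]]:
    """Group (word, IOB tag) pairs into slot values, e.g. {'title': ['zorro']}."""
    spans = []       # (slot name, words) pairs in order of appearance
    current = None   # the span currently being extended, None after an 'O'
    for word, tag in tagged:
        if tag == 'O':
            current = None
            continue
        prefix, slot = tag.split('-', 1)
        if prefix == 'I' and current is not None and current[0] == slot:
            current[1].append(word)
        else:
            # a 'B-' tag, or an 'I-' tag that does not continue a slot: start a new span
            current = (slot, [word])
            spans.append(current)
    slots: Dict[str, List[str]] = {}
    for slot, words in spans:
        slots.setdefault(slot, []).append(' '.join(words))
    return slots


prediction = [('kiedy', 'O'), ('gracie', 'O'), ('film', 'O'), ('zorro', 'B-title')]
print(iob_to_slots(prediction))  # {'title': ['zorro']}
```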
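As the table below shows, the model tags the word zorro as the beginning of a title slot (B-title) and the remaining words as O. The test scores above (micro $F_1$ of 0.3548), however, show that a model trained on only 345 utterances still makes many mistakes.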
tabulate(predict(model, 'kiedy gracie film zorro'.split()), tablefmt='html')
kiedy | O |
gracie | O |
film | O |
zorro | B-title |
References
- Sebastian Schuster, Sonal Gupta, Rushin Shah, Mike Lewis. 2019. Cross-lingual Transfer Learning for Multilingual Task Oriented Dialog. In Proceedings of NAACL-HLT 2019 (Volume 1), pp. 3795–3805.
- John D. Lafferty, Andrew McCallum, Fernando C. N. Pereira. 2001. Conditional Random Fields: Probabilistic Models for Segmenting and Labeling Sequence Data. In Proceedings of the Eighteenth International Conference on Machine Learning (ICML '01), Morgan Kaufmann Publishers Inc., San Francisco, CA, USA, pp. 282–289, https://repository.upenn.edu/cgi/viewcontent.cgi?article=1162&context=cis_papers
- Sepp Hochreiter, Jürgen Schmidhuber. 1997. Long Short-Term Memory. Neural Computation 9(8), pp. 1735–1780, https://doi.org/10.1162/neco.1997.9.8.1735
- Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin. 2017. Attention is All you Need. In NIPS 2017, pp. 5998–6008, https://arxiv.org/abs/1706.03762
- Alan Akbik, Duncan Blythe, Roland Vollgraf. 2018. Contextual String Embeddings for Sequence Labeling. In Proceedings of the 27th International Conference on Computational Linguistics, pp. 1638–1649, https://www.aclweb.org/anthology/C18-1139.pdf