Systemy Dialogowe

8. Semantic parsing using machine learning techniques [lab]

Marek Kubis (2021)


Semantic parsing using machine learning techniques

Introduction

The problem of detecting slots and their values in user utterances can be formulated as the task of predicting, for each word, a label that indicates whether and to which slot the word belongs.

chciałbym zarezerwować stolik na jutro/day na godzinę dwunastą/hour czterdzieści/hour pięć/hour na pięć/size osób
("I would like to book a table for tomorrow at twelve forty-five for five people.")

Slot boundaries are marked using a chosen labeling scheme.

The IOB scheme

Prefix  Meaning
I       inside of a slot
O       outside of any slot
B       beginning of a slot

chciałbym zarezerwować stolik na jutro/B-day na godzinę dwunastą/B-hour czterdzieści/I-hour pięć/I-hour na pięć/B-size osób
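
To make the mapping from slot spans to token labels concrete, the sketch below (not part of the original lab; the (first_token, last_token, slot) span format is a hypothetical input representation assumed only for illustration) reproduces the IOB labeling shown above.

# A minimal sketch: deriving IOB labels from hypothetical slot spans.
def spans_to_iob(tokens, spans):
    labels = ['O'] * len(tokens)
    for first, last, slot in spans:
        labels[first] = f'B-{slot}'              # first word of the slot
        for i in range(first + 1, last + 1):
            labels[i] = f'I-{slot}'              # remaining words of the slot
    return list(zip(tokens, labels))

tokens = 'chciałbym zarezerwować stolik na jutro na godzinę dwunastą czterdzieści pięć na pięć osób'.split()
spans = [(4, 4, 'day'), (7, 9, 'hour'), (11, 11, 'size')]
print(spans_to_iob(tokens, spans))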

The IOBES scheme

Prefix  Meaning
I       inside of a slot
O       outside of any slot
B       beginning of a slot
E       end of a slot
S       single-word slot

chciałbym zarezerwować stolik na jutro/S-day na godzinę dwunastą/B-hour czterdzieści/I-hour pięć/E-hour na pięć/S-size osób
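
The IOBES labeling can be derived mechanically from IOB: a B- tag not followed by an I- tag of the same slot becomes S-, and the last I- tag of a slot becomes E-. The sketch below (again only an illustration, not part of the original lab) performs this conversion.

# A minimal sketch: converting IOB labels to IOBES labels.
def iob2iobes(labels):
    iobes = []
    for i, label in enumerate(labels):
        nxt = labels[i + 1] if i + 1 < len(labels) else 'O'
        if label.startswith('B-') and nxt != 'I-' + label[2:]:
            iobes.append('S-' + label[2:])       # single-word slot
        elif label.startswith('I-') and nxt != 'I-' + label[2:]:
            iobes.append('E-' + label[2:])       # last word of a slot
        else:
            iobes.append(label)
    return iobes

print(iob2iobes(['O', 'B-day', 'O', 'B-hour', 'I-hour', 'I-hour', 'O', 'B-size', 'O']))
# ['O', 'S-day', 'O', 'B-hour', 'I-hour', 'E-hour', 'O', 'S-size', 'O']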

If for a task formulated in this way we prepare a dataset of user utterances with annotated slots (a so-called training set), we can apply (supervised) machine learning techniques to build a model that annotates user utterances with slot labels.

Such a model can be built using, among others:

  1. conditional random fields (Lafferty et al., 2001; a short sketch is shown after this list),

  2. recurrent neural networks, e.g. LSTM networks (Hochreiter and Schmidhuber, 1997),

  3. transformers (Vaswani et al., 2017).
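
For instance, a CRF-based slot tagger could be put together with the sklearn-crfsuite package. The sketch below is only an illustration on assumed toy data; this library and the hand-crafted features are not used later in the lab, which relies on flair instead.

# A hedged sketch of a CRF slot tagger built with sklearn-crfsuite.
import sklearn_crfsuite

def token_features(tokens, i):
    # simple hand-crafted features describing the i-th token and its neighbours
    return {
        'word': tokens[i].lower(),
        'is_digit': tokens[i].isdigit(),
        'prev_word': tokens[i - 1].lower() if i > 0 else '<BOS>',
        'next_word': tokens[i + 1].lower() if i < len(tokens) - 1 else '<EOS>',
    }

# toy training data annotated with the IOB scheme
sentences = [('obudź mnie jutro o siódmej'.split(),
              ['O', 'O', 'B-date', 'O', 'B-time'])]
X = [[token_features(tokens, i) for i in range(len(tokens))] for tokens, _ in sentences]
y = [labels for _, labels in sentences]

crf = sklearn_crfsuite.CRF(algorithm='lbfgs', c1=0.1, c2=0.1, max_iterations=100)
crf.fit(X, y)
print(crf.predict(X))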

Example

We will use the dataset prepared by Schuster et al. (2019).

!mkdir -p l07
%cd l07
!curl -L -C -  https://fb.me/multilingual_task_oriented_data  -o data.zip
!unzip data.zip
%cd ..

This dataset contains utterances in three languages, annotated with slots for twelve frames belonging to three domains: Alarm, Reminder and Weather. We will load the data using the conllu library.

from conllu import parse_incr
fields = ['id', 'form', 'frame', 'slot']

# map the 'NoLabel' marker used in the corpus to the outside tag 'O'
def nolabel2o(line, i):
    return 'O' if line[i] == 'NoLabel' else line[i]

with open('./train_data/train.conllu') as trainfile:
    trainset = list(parse_incr(trainfile, fields=fields, field_parsers={'slot': nolabel2o}))
with open('./train_data/test.conllu') as testfile:
    testset = list(parse_incr(testfile, fields=fields, field_parsers={'slot': nolabel2o}))

Let us look at a few sample utterances from this dataset.

from tabulate import tabulate
tabulate(trainset[26], tablefmt='html')
1  chciałbym  O
2  kupić      O
3  popcorn    O
tabulate(trainset[1000], tablefmt='html')
tabulate(trainset[2000], tablefmt='html')

For the purposes of demonstrating the training process in a Jupyter notebook, we will narrow the dataset down to its initial examples.

trainset = trainset[:100]
testset = testset[:100]

To build the model, we will use an architecture based on recurrent neural networks, implemented in the flair library (Akbik et al., 2018).

from flair.data import Corpus, Sentence, Token
from flair.datasets import SentenceDataset
from flair.embeddings import StackedEmbeddings
from flair.embeddings import WordEmbeddings
from flair.embeddings import CharacterEmbeddings
from flair.embeddings import FlairEmbeddings
from flair.models import SequenceTagger
from flair.trainers import ModelTrainer

# make the computations deterministic (fixed random seeds)
import random
import torch
random.seed(42)
torch.manual_seed(42)

if torch.cuda.is_available():
    torch.cuda.manual_seed(0)
    torch.cuda.manual_seed_all(0)
    torch.backends.cudnn.enabled = False
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
C:\Users\Adrian\AppData\Roaming\Python\Python37\site-packages\tqdm\auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm

We will convert the data to the format used by flair, using the following function.

def conllu2flair(sentences, label=None):
    fsentences = []

    for sentence in sentences:
        fsentence = Sentence()

        for token in sentence:
            ftoken = Token(token['form'])

            # copy the gold label (if any) onto the flair token
            if label:
                ftoken.add_tag(label, token[label])

            fsentence.add_token(ftoken)

        fsentences.append(fsentence)

    return SentenceDataset(fsentences)

corpus = Corpus(train=conllu2flair(trainset, 'slot'), test=conllu2flair(testset, 'slot'))
print(corpus)
tag_dictionary = corpus.make_tag_dictionary(tag_type='slot')
print(tag_dictionary)
Corpus: 194 train + 22 dev + 33 test sentences
Dictionary with 12 tags: <unk>, O, B-time, I-time, B-area, I-area, B-quantity, B-date, I-quantity, I-date, <START>, <STOP>

Our model will use vector representations of words (see Word Embeddings).

embedding_types = [
    WordEmbeddings('pl'),
    FlairEmbeddings('polish-forward'),
    FlairEmbeddings('polish-backward'),
    CharacterEmbeddings(),
]

embeddings = StackedEmbeddings(embeddings=embedding_types)
tagger = SequenceTagger(hidden_size=256, embeddings=embeddings,
                        tag_dictionary=tag_dictionary,
                        tag_type='slot', use_crf=True)
2022-05-17 23:51:31,428 https://flair.informatik.hu-berlin.de/resources/embeddings/token/pl-wiki-fasttext-300d-1M.vectors.npy not found in cache, downloading to C:\Users\Adrian\AppData\Local\Temp\tmpdtf6je0q
100%|██████████| 1199998928/1199998928 [00:38<00:00, 31059047.54B/s]
2022-05-17 23:52:10,221 copying C:\Users\Adrian\AppData\Local\Temp\tmpdtf6je0q to cache at C:\Users\Adrian\.flair\embeddings\pl-wiki-fasttext-300d-1M.vectors.npy
2022-05-17 23:52:11,581 removing temp file C:\Users\Adrian\AppData\Local\Temp\tmpdtf6je0q
2022-05-17 23:52:11,834 https://flair.informatik.hu-berlin.de/resources/embeddings/token/pl-wiki-fasttext-300d-1M not found in cache, downloading to C:\Users\Adrian\AppData\Local\Temp\tmpncdt74ud
100%|██████████| 40874795/40874795 [00:01<00:00, 25496548.48B/s]
2022-05-17 23:52:13,623 copying C:\Users\Adrian\AppData\Local\Temp\tmpncdt74ud to cache at C:\Users\Adrian\.flair\embeddings\pl-wiki-fasttext-300d-1M
2022-05-17 23:52:13,678 removing temp file C:\Users\Adrian\AppData\Local\Temp\tmpncdt74ud
2022-05-17 23:52:21,696 https://flair.informatik.hu-berlin.de/resources/embeddings/flair/lm-polish-forward-v0.2.pt not found in cache, downloading to C:\Users\Adrian\AppData\Local\Temp\tmp6okeka8n
100%|██████████| 84244196/84244196 [00:02<00:00, 35143826.68B/s]
2022-05-17 23:52:24,338 copying C:\Users\Adrian\AppData\Local\Temp\tmp6okeka8n to cache at C:\Users\Adrian\.flair\embeddings\lm-polish-forward-v0.2.pt
2022-05-17 23:52:24,435 removing temp file C:\Users\Adrian\AppData\Local\Temp\tmp6okeka8n
2022-05-17 23:52:24,857 https://flair.informatik.hu-berlin.de/resources/embeddings/flair/lm-polish-backward-v0.2.pt not found in cache, downloading to C:\Users\Adrian\AppData\Local\Temp\tmp_6ut1zi9
100%|██████████| 84244196/84244196 [00:02<00:00, 35815492.94B/s]
2022-05-17 23:52:27,375 copying C:\Users\Adrian\AppData\Local\Temp\tmp_6ut1zi9 to cache at C:\Users\Adrian\.flair\embeddings\lm-polish-backward-v0.2.pt
2022-05-17 23:52:27,460 removing temp file C:\Users\Adrian\AppData\Local\Temp\tmp_6ut1zi9

Let us look at the architecture of the neural network that will be responsible for predicting slots in utterances.

print(tagger)
SequenceTagger(
  (embeddings): StackedEmbeddings(
    (list_embedding_0): WordEmbeddings('pl')
    (list_embedding_1): FlairEmbeddings(
      (lm): LanguageModel(
        (drop): Dropout(p=0.25, inplace=False)
        (encoder): Embedding(1602, 100)
        (rnn): LSTM(100, 2048)
        (decoder): Linear(in_features=2048, out_features=1602, bias=True)
      )
    )
    (list_embedding_2): FlairEmbeddings(
      (lm): LanguageModel(
        (drop): Dropout(p=0.25, inplace=False)
        (encoder): Embedding(1602, 100)
        (rnn): LSTM(100, 2048)
        (decoder): Linear(in_features=2048, out_features=1602, bias=True)
      )
    )
    (list_embedding_3): CharacterEmbeddings(
      (char_embedding): Embedding(275, 25)
      (char_rnn): LSTM(25, 25, bidirectional=True)
    )
  )
  (word_dropout): WordDropout(p=0.05)
  (locked_dropout): LockedDropout(p=0.5)
  (embedding2nn): Linear(in_features=4446, out_features=4446, bias=True)
  (rnn): LSTM(4446, 256, batch_first=True, bidirectional=True)
  (linear): Linear(in_features=512, out_features=12, bias=True)
  (beta): 1.0
  (weights): None
  (weight_tensor) None
)

We will run ten training iterations (epochs) and save the resulting model in the slot-model directory.

trainer = ModelTrainer(tagger, corpus)
trainer.train('slot-model',
              learning_rate=0.1,
              mini_batch_size=32,
              max_epochs=10,
              train_with_dev=False)
2022-05-17 23:52:57,432 ----------------------------------------------------------------------------------------------------
2022-05-17 23:52:57,433 Model: "SequenceTagger(
  (embeddings): StackedEmbeddings(
    (list_embedding_0): WordEmbeddings('pl')
    (list_embedding_1): FlairEmbeddings(
      (lm): LanguageModel(
        (drop): Dropout(p=0.25, inplace=False)
        (encoder): Embedding(1602, 100)
        (rnn): LSTM(100, 2048)
        (decoder): Linear(in_features=2048, out_features=1602, bias=True)
      )
    )
    (list_embedding_2): FlairEmbeddings(
      (lm): LanguageModel(
        (drop): Dropout(p=0.25, inplace=False)
        (encoder): Embedding(1602, 100)
        (rnn): LSTM(100, 2048)
        (decoder): Linear(in_features=2048, out_features=1602, bias=True)
      )
    )
    (list_embedding_3): CharacterEmbeddings(
      (char_embedding): Embedding(275, 25)
      (char_rnn): LSTM(25, 25, bidirectional=True)
    )
  )
  (word_dropout): WordDropout(p=0.05)
  (locked_dropout): LockedDropout(p=0.5)
  (embedding2nn): Linear(in_features=4446, out_features=4446, bias=True)
  (rnn): LSTM(4446, 256, batch_first=True, bidirectional=True)
  (linear): Linear(in_features=512, out_features=12, bias=True)
  (beta): 1.0
  (weights): None
  (weight_tensor) None
)"
2022-05-17 23:52:57,434 ----------------------------------------------------------------------------------------------------
2022-05-17 23:52:57,435 Corpus: "Corpus: 194 train + 22 dev + 33 test sentences"
2022-05-17 23:52:57,435 ----------------------------------------------------------------------------------------------------
2022-05-17 23:52:57,435 Parameters:
2022-05-17 23:52:57,436  - learning_rate: "0.1"
2022-05-17 23:52:57,437  - mini_batch_size: "32"
2022-05-17 23:52:57,437  - patience: "3"
2022-05-17 23:52:57,437  - anneal_factor: "0.5"
2022-05-17 23:52:57,438  - max_epochs: "10"
2022-05-17 23:52:57,439  - shuffle: "True"
2022-05-17 23:52:57,440  - train_with_dev: "False"
2022-05-17 23:52:57,440  - batch_growth_annealing: "False"
2022-05-17 23:52:57,441 ----------------------------------------------------------------------------------------------------
2022-05-17 23:52:57,441 Model training base path: "slot-model"
2022-05-17 23:52:57,442 ----------------------------------------------------------------------------------------------------
2022-05-17 23:52:57,443 Device: cpu
2022-05-17 23:52:57,443 ----------------------------------------------------------------------------------------------------
2022-05-17 23:52:57,444 Embeddings storage mode: cpu
2022-05-17 23:52:57,446 ----------------------------------------------------------------------------------------------------
2022-05-17 23:52:59,206 epoch 1 - iter 1/7 - loss 16.77810669 - samples/sec: 18.23 - lr: 0.100000
2022-05-17 23:53:01,036 epoch 1 - iter 2/7 - loss 15.17136908 - samples/sec: 17.51 - lr: 0.100000
2022-05-17 23:53:02,450 epoch 1 - iter 3/7 - loss 13.45863914 - samples/sec: 22.63 - lr: 0.100000
2022-05-17 23:53:04,163 epoch 1 - iter 4/7 - loss 11.81387305 - samples/sec: 18.70 - lr: 0.100000
2022-05-17 23:53:06,030 epoch 1 - iter 5/7 - loss 10.41218300 - samples/sec: 17.14 - lr: 0.100000
2022-05-17 23:53:07,655 epoch 1 - iter 6/7 - loss 9.20362504 - samples/sec: 19.70 - lr: 0.100000
2022-05-17 23:53:07,968 epoch 1 - iter 7/7 - loss 8.10721644 - samples/sec: 102.61 - lr: 0.100000
2022-05-17 23:53:07,969 ----------------------------------------------------------------------------------------------------
2022-05-17 23:53:07,970 EPOCH 1 done: loss 8.1072 - lr 0.1000000
2022-05-17 23:53:09,606 DEV : loss 3.991352081298828 - score 0.2
2022-05-17 23:53:09,607 BAD EPOCHS (no improvement): 0
saving best model
2022-05-17 23:53:14,975 ----------------------------------------------------------------------------------------------------
2022-05-17 23:53:15,484 epoch 2 - iter 1/7 - loss 3.58558130 - samples/sec: 63.03 - lr: 0.100000
2022-05-17 23:53:15,865 epoch 2 - iter 2/7 - loss 3.12797976 - samples/sec: 84.20 - lr: 0.100000
2022-05-17 23:53:16,267 epoch 2 - iter 3/7 - loss 2.60615242 - samples/sec: 79.80 - lr: 0.100000
2022-05-17 23:53:16,738 epoch 2 - iter 4/7 - loss 2.71958175 - samples/sec: 68.18 - lr: 0.100000
2022-05-17 23:53:17,170 epoch 2 - iter 5/7 - loss 2.70331609 - samples/sec: 74.26 - lr: 0.100000
2022-05-17 23:53:17,603 epoch 2 - iter 6/7 - loss 2.51522466 - samples/sec: 74.01 - lr: 0.100000
2022-05-17 23:53:17,748 epoch 2 - iter 7/7 - loss 2.19215042 - samples/sec: 221.61 - lr: 0.100000
2022-05-17 23:53:17,749 ----------------------------------------------------------------------------------------------------
2022-05-17 23:53:17,750 EPOCH 2 done: loss 2.1922 - lr 0.1000000
2022-05-17 23:53:17,844 DEV : loss 3.9842920303344727 - score 0.3636
2022-05-17 23:53:17,846 BAD EPOCHS (no improvement): 0
saving best model
2022-05-17 23:53:22,865 ----------------------------------------------------------------------------------------------------
2022-05-17 23:53:23,305 epoch 3 - iter 1/7 - loss 2.19582605 - samples/sec: 72.76 - lr: 0.100000
2022-05-17 23:53:23,741 epoch 3 - iter 2/7 - loss 1.85529530 - samples/sec: 73.58 - lr: 0.100000
2022-05-17 23:53:24,212 epoch 3 - iter 3/7 - loss 1.91948136 - samples/sec: 68.09 - lr: 0.100000
2022-05-17 23:53:24,717 epoch 3 - iter 4/7 - loss 2.11527669 - samples/sec: 63.50 - lr: 0.100000
2022-05-17 23:53:25,129 epoch 3 - iter 5/7 - loss 2.12587404 - samples/sec: 77.75 - lr: 0.100000
2022-05-17 23:53:25,630 epoch 3 - iter 6/7 - loss 2.01592445 - samples/sec: 63.92 - lr: 0.100000
2022-05-17 23:53:25,755 epoch 3 - iter 7/7 - loss 1.73551549 - samples/sec: 258.75 - lr: 0.100000
2022-05-17 23:53:25,756 ----------------------------------------------------------------------------------------------------
2022-05-17 23:53:25,757 EPOCH 3 done: loss 1.7355 - lr 0.1000000
2022-05-17 23:53:25,854 DEV : loss 3.3194284439086914 - score 0.3077
2022-05-17 23:53:25,855 BAD EPOCHS (no improvement): 1
2022-05-17 23:53:25,856 ----------------------------------------------------------------------------------------------------
2022-05-17 23:53:26,274 epoch 4 - iter 1/7 - loss 1.46010232 - samples/sec: 76.66 - lr: 0.100000
2022-05-17 23:53:26,734 epoch 4 - iter 2/7 - loss 1.18807647 - samples/sec: 69.66 - lr: 0.100000
2022-05-17 23:53:27,229 epoch 4 - iter 3/7 - loss 1.33144226 - samples/sec: 64.87 - lr: 0.100000
2022-05-17 23:53:27,775 epoch 4 - iter 4/7 - loss 1.64428358 - samples/sec: 58.69 - lr: 0.100000
2022-05-17 23:53:28,243 epoch 4 - iter 5/7 - loss 1.62551130 - samples/sec: 68.71 - lr: 0.100000
2022-05-17 23:53:28,727 epoch 4 - iter 6/7 - loss 1.74551653 - samples/sec: 66.25 - lr: 0.100000
2022-05-17 23:53:28,856 epoch 4 - iter 7/7 - loss 1.53921426 - samples/sec: 248.73 - lr: 0.100000
2022-05-17 23:53:28,857 ----------------------------------------------------------------------------------------------------
2022-05-17 23:53:28,858 EPOCH 4 done: loss 1.5392 - lr 0.1000000
2022-05-17 23:53:28,962 DEV : loss 2.8986825942993164 - score 0.2857
2022-05-17 23:53:28,963 BAD EPOCHS (no improvement): 2
2022-05-17 23:53:28,965 ----------------------------------------------------------------------------------------------------
2022-05-17 23:53:29,417 epoch 5 - iter 1/7 - loss 1.72827125 - samples/sec: 70.90 - lr: 0.100000
2022-05-17 23:53:29,902 epoch 5 - iter 2/7 - loss 1.51951337 - samples/sec: 66.07 - lr: 0.100000
2022-05-17 23:53:30,355 epoch 5 - iter 3/7 - loss 1.55555471 - samples/sec: 70.83 - lr: 0.100000
2022-05-17 23:53:30,840 epoch 5 - iter 4/7 - loss 1.31492138 - samples/sec: 66.16 - lr: 0.100000
2022-05-17 23:53:31,257 epoch 5 - iter 5/7 - loss 1.46497860 - samples/sec: 76.92 - lr: 0.100000
2022-05-17 23:53:31,768 epoch 5 - iter 6/7 - loss 1.60987592 - samples/sec: 62.75 - lr: 0.100000
2022-05-17 23:53:31,929 epoch 5 - iter 7/7 - loss 2.72113044 - samples/sec: 200.53 - lr: 0.100000
2022-05-17 23:53:31,930 ----------------------------------------------------------------------------------------------------
2022-05-17 23:53:31,931 EPOCH 5 done: loss 2.7211 - lr 0.1000000
2022-05-17 23:53:32,024 DEV : loss 2.766446590423584 - score 0.3077
2022-05-17 23:53:32,025 BAD EPOCHS (no improvement): 3
2022-05-17 23:53:32,026 ----------------------------------------------------------------------------------------------------
2022-05-17 23:53:32,475 epoch 6 - iter 1/7 - loss 1.68398678 - samples/sec: 71.62 - lr: 0.100000
2022-05-17 23:53:32,971 epoch 6 - iter 2/7 - loss 1.67541099 - samples/sec: 64.62 - lr: 0.100000
2022-05-17 23:53:33,400 epoch 6 - iter 3/7 - loss 1.58060956 - samples/sec: 74.78 - lr: 0.100000
2022-05-17 23:53:33,878 epoch 6 - iter 4/7 - loss 1.55456299 - samples/sec: 66.92 - lr: 0.100000
2022-05-17 23:53:34,278 epoch 6 - iter 5/7 - loss 1.50003145 - samples/sec: 80.28 - lr: 0.100000
2022-05-17 23:53:34,813 epoch 6 - iter 6/7 - loss 1.46878848 - samples/sec: 60.04 - lr: 0.100000
2022-05-17 23:53:34,951 epoch 6 - iter 7/7 - loss 1.66172016 - samples/sec: 233.22 - lr: 0.100000
2022-05-17 23:53:34,952 ----------------------------------------------------------------------------------------------------
2022-05-17 23:53:34,952 EPOCH 6 done: loss 1.6617 - lr 0.1000000
2022-05-17 23:53:35,040 DEV : loss 2.2595832347869873 - score 0.2857
Epoch     6: reducing learning rate of group 0 to 5.0000e-02.
2022-05-17 23:53:35,041 BAD EPOCHS (no improvement): 4
2022-05-17 23:53:35,043 ----------------------------------------------------------------------------------------------------
2022-05-17 23:53:35,461 epoch 7 - iter 1/7 - loss 1.14667833 - samples/sec: 76.93 - lr: 0.050000
2022-05-17 23:53:35,976 epoch 7 - iter 2/7 - loss 1.11618459 - samples/sec: 62.22 - lr: 0.050000
2022-05-17 23:53:36,416 epoch 7 - iter 3/7 - loss 1.24378494 - samples/sec: 72.88 - lr: 0.050000
2022-05-17 23:53:36,880 epoch 7 - iter 4/7 - loss 1.31663331 - samples/sec: 69.14 - lr: 0.050000
2022-05-17 23:53:37,298 epoch 7 - iter 5/7 - loss 1.39581544 - samples/sec: 76.75 - lr: 0.050000
2022-05-17 23:53:37,714 epoch 7 - iter 6/7 - loss 1.34690581 - samples/sec: 77.09 - lr: 0.050000
2022-05-17 23:53:37,860 epoch 7 - iter 7/7 - loss 1.46004195 - samples/sec: 220.36 - lr: 0.050000
2022-05-17 23:53:37,861 ----------------------------------------------------------------------------------------------------
2022-05-17 23:53:37,861 EPOCH 7 done: loss 1.4600 - lr 0.0500000
2022-05-17 23:53:37,954 DEV : loss 2.200728416442871 - score 0.2857
2022-05-17 23:53:37,955 BAD EPOCHS (no improvement): 1
2022-05-17 23:53:37,956 ----------------------------------------------------------------------------------------------------
2022-05-17 23:53:38,423 epoch 8 - iter 1/7 - loss 1.14459288 - samples/sec: 68.83 - lr: 0.050000
2022-05-17 23:53:38,805 epoch 8 - iter 2/7 - loss 0.95714736 - samples/sec: 83.88 - lr: 0.050000
2022-05-17 23:53:39,302 epoch 8 - iter 3/7 - loss 1.17704646 - samples/sec: 64.42 - lr: 0.050000
2022-05-17 23:53:39,781 epoch 8 - iter 4/7 - loss 1.29963121 - samples/sec: 66.92 - lr: 0.050000
2022-05-17 23:53:40,256 epoch 8 - iter 5/7 - loss 1.34262223 - samples/sec: 67.59 - lr: 0.050000
2022-05-17 23:53:40,704 epoch 8 - iter 6/7 - loss 1.33356750 - samples/sec: 71.53 - lr: 0.050000
2022-05-17 23:53:40,846 epoch 8 - iter 7/7 - loss 1.20113390 - samples/sec: 226.59 - lr: 0.050000
2022-05-17 23:53:40,847 ----------------------------------------------------------------------------------------------------
2022-05-17 23:53:40,848 EPOCH 8 done: loss 1.2011 - lr 0.0500000
2022-05-17 23:53:40,941 DEV : loss 2.4227261543273926 - score 0.2857
2022-05-17 23:53:40,942 BAD EPOCHS (no improvement): 2
2022-05-17 23:53:40,943 ----------------------------------------------------------------------------------------------------
2022-05-17 23:53:41,389 epoch 9 - iter 1/7 - loss 1.12297106 - samples/sec: 71.73 - lr: 0.050000
2022-05-17 23:53:41,800 epoch 9 - iter 2/7 - loss 0.92356640 - samples/sec: 78.01 - lr: 0.050000
2022-05-17 23:53:42,249 epoch 9 - iter 3/7 - loss 1.02407436 - samples/sec: 71.37 - lr: 0.050000
2022-05-17 23:53:42,667 epoch 9 - iter 4/7 - loss 1.04805315 - samples/sec: 76.71 - lr: 0.050000
2022-05-17 23:53:43,215 epoch 9 - iter 5/7 - loss 1.33371143 - samples/sec: 58.59 - lr: 0.050000
2022-05-17 23:53:43,661 epoch 9 - iter 6/7 - loss 1.27829826 - samples/sec: 71.89 - lr: 0.050000
2022-05-17 23:53:43,796 epoch 9 - iter 7/7 - loss 1.10260926 - samples/sec: 240.25 - lr: 0.050000
2022-05-17 23:53:43,797 ----------------------------------------------------------------------------------------------------
2022-05-17 23:53:43,798 EPOCH 9 done: loss 1.1026 - lr 0.0500000
2022-05-17 23:53:43,895 DEV : loss 2.1707162857055664 - score 0.3077
2022-05-17 23:53:43,896 BAD EPOCHS (no improvement): 3
2022-05-17 23:53:43,903 ----------------------------------------------------------------------------------------------------
2022-05-17 23:53:44,338 epoch 10 - iter 1/7 - loss 1.34320462 - samples/sec: 73.74 - lr: 0.050000
2022-05-17 23:53:44,808 epoch 10 - iter 2/7 - loss 0.96772069 - samples/sec: 68.25 - lr: 0.050000
2022-05-17 23:53:45,207 epoch 10 - iter 3/7 - loss 1.06257542 - samples/sec: 80.34 - lr: 0.050000
2022-05-17 23:53:45,729 epoch 10 - iter 4/7 - loss 0.92318819 - samples/sec: 61.50 - lr: 0.050000
2022-05-17 23:53:46,202 epoch 10 - iter 5/7 - loss 1.08295707 - samples/sec: 67.82 - lr: 0.050000
2022-05-17 23:53:46,707 epoch 10 - iter 6/7 - loss 1.18012399 - samples/sec: 63.49 - lr: 0.050000
2022-05-17 23:53:46,841 epoch 10 - iter 7/7 - loss 1.01267667 - samples/sec: 239.34 - lr: 0.050000
2022-05-17 23:53:46,842 ----------------------------------------------------------------------------------------------------
2022-05-17 23:53:46,842 EPOCH 10 done: loss 1.0127 - lr 0.0500000
2022-05-17 23:53:46,942 DEV : loss 1.9863343238830566 - score 0.3077
Epoch    10: reducing learning rate of group 0 to 2.5000e-02.
2022-05-17 23:53:46,943 BAD EPOCHS (no improvement): 4
2022-05-17 23:53:51,951 ----------------------------------------------------------------------------------------------------
2022-05-17 23:53:51,952 Testing using best model ...
2022-05-17 23:53:51,953 loading file slot-model\best-model.pt
2022-05-17 23:53:57,745 0.8000	0.2667	0.4000
2022-05-17 23:53:57,746 
Results:
- F1-score (micro) 0.4000
- F1-score (macro) 0.2424

By class:
date       tp: 0 - fp: 0 - fn: 4 - precision: 0.0000 - recall: 0.0000 - f1-score: 0.0000
quantity   tp: 4 - fp: 1 - fn: 2 - precision: 0.8000 - recall: 0.6667 - f1-score: 0.7273
time       tp: 0 - fp: 0 - fn: 5 - precision: 0.0000 - recall: 0.0000 - f1-score: 0.0000
2022-05-17 23:53:57,747 ----------------------------------------------------------------------------------------------------
{'test_score': 0.4,
 'dev_score_history': [0.2,
  0.36363636363636365,
  0.3076923076923077,
  0.28571428571428575,
  0.3076923076923077,
  0.28571428571428575,
  0.28571428571428575,
  0.28571428571428575,
  0.3076923076923077,
  0.3076923076923077],
 'train_loss_history': [8.107216443334307,
  2.19215042250497,
  1.735515492303031,
  1.5392142619405473,
  2.721130439213344,
  1.6617201566696167,
  1.460041948727199,
  1.2011338983263289,
  1.1026092597416468,
  1.012676673276084],
 'dev_loss_history': [3.991352081298828,
  3.9842920303344727,
  3.3194284439086914,
  2.8986825942993164,
  2.766446590423584,
  2.2595832347869873,
  2.200728416442871,
  2.4227261543273926,
  2.1707162857055664,
  1.9863343238830566]}

The quality of the trained model can be assessed using the metrics reported above, i.e.:

  • tp (true positives)

    the number of words labeled in the test set with tag $e$ that the model also labeled with $e$

  • fp (false positives)

    the number of words not labeled in the test set with tag $e$ that the model labeled with $e$

  • fn (false negatives)

    the number of words labeled in the test set with tag $e$ that the model did not label with $e$

  • precision

    $$\frac{tp}{tp + fp}$$

  • recall

    $$\frac{tp}{tp + fn}$$

  • $F_1$

    $$\frac{2 \cdot precision \cdot recall}{precision + recall}$$

  • micro $F_1$

    $F_1$ in which $tp$, $fp$ and $fn$ are counted jointly over all labels, i.e. $tp = \sum_{e}{tp_e}$, $fn = \sum_{e}{fn_e}$, $fp = \sum_{e}{fp_e}$

  • macro $F_1$

    the arithmetic mean of the $F_1$ scores computed for each label separately (a worked check against the values reported above follows).
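
As a sanity check, the per-class counts reported above can be plugged into these formulas; the snippet below reproduces the micro F1 of 0.4000 and the macro F1 of 0.2424 printed by flair.

# Recomputing micro and macro F1 from the per-class counts reported above.
counts = {'date': (0, 0, 4), 'quantity': (4, 1, 2), 'time': (0, 0, 5)}  # (tp, fp, fn)

def f1(tp, fp, fn):
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0

tp, fp, fn = (sum(c[i] for c in counts.values()) for i in range(3))
print('micro F1:', round(f1(tp, fp, fn), 4))                                      # 0.4
print('macro F1:', round(sum(f1(*c) for c in counts.values()) / len(counts), 4))  # 0.2424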

The trained model can be loaded from a file using the load method.

model = SequenceTagger.load('slot-model/final-model.pt')
2022-05-17 23:57:03,014 loading file slot-model/final-model.pt

The loaded model can be used to predict slots in user utterances with the predict function shown below.

def predict(model, sentence):
    # wrap the raw word list in the conllu-like structure expected by conllu2flair
    csentence = [{'form': word} for word in sentence]
    fsentence = conllu2flair([csentence])[0]
    model.predict(fsentence)
    return [(token, ftoken.get_tag('slot').value) for token, ftoken in zip(sentence, fsentence)]

As the example below shows, a model trained on only 100 examples makes mistakes even in a fairly simple utterance, e.g. it labels the word 19:30 with the B-quantity tag instead of a time slot (B-time).

tabulate(predict(model, 'chciałbym zarezerwować 2 bilety na batman na 19:30 na środku z tyłun po prawej i po lewej nie chce z przodu'.split()), tablefmt='html')
chciałbym     O
zarezerwować  O
2             B-quantity
bilety        O
na            O
batman        O
na            O
19:30         B-quantity
na            O
środku        O
z             O
tyłun         O
po            O
prawej        O
i             O
po            O
lewej         O
nie           O
chce          O
z             O
przodu        O

References

  1. Sebastian Schuster, Sonal Gupta, Rushin Shah, Mike Lewis, Cross-lingual Transfer Learning for Multilingual Task Oriented Dialog. NAACL-HLT (1) 2019, pp. 3795-3805
  2. John D. Lafferty, Andrew McCallum, and Fernando C. N. Pereira. 2001. Conditional Random Fields: Probabilistic Models for Segmenting and Labeling Sequence Data. In Proceedings of the Eighteenth International Conference on Machine Learning (ICML '01). Morgan Kaufmann Publishers Inc., San Francisco, CA, USA, pp. 282-289, https://repository.upenn.edu/cgi/viewcontent.cgi?article=1162&context=cis_papers
  3. Sepp Hochreiter and Jürgen Schmidhuber. 1997. Long Short-Term Memory. Neural Computation 9, 8 (November 15, 1997), pp. 1735-1780, https://doi.org/10.1162/neco.1997.9.8.1735
  4. Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, Illia Polosukhin, Attention is All you Need, NIPS 2017, pp. 5998-6008, https://arxiv.org/abs/1706.03762
  5. Alan Akbik, Duncan Blythe, Roland Vollgraf, Contextual String Embeddings for Sequence Labeling, Proceedings of the 27th International Conference on Computational Linguistics, 2018, pp. 1638-1649, https://www.aclweb.org/anthology/C18-1139.pdf