322 KiB
322 KiB
# Install simpletransformers quietly (-q); `!` is IPython shell-escape syntax, so this line only runs in a notebook
!pip -q install simpletransformers
[2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m250.5/250.5 KB[0m [31m18.2 MB/s[0m eta [36m0:00:00[0m [2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.6/7.6 MB[0m [31m105.3 MB/s[0m eta [36m0:00:00[0m [2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.0/2.0 MB[0m [31m81.7 MB/s[0m eta [36m0:00:00[0m [2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m452.9/452.9 KB[0m [31m47.6 MB/s[0m eta [36m0:00:00[0m [2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.3/1.3 MB[0m [31m78.6 MB/s[0m eta [36m0:00:00[0m [2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 KB[0m [31m5.6 MB/s[0m eta [36m0:00:00[0m [?25h Preparing metadata (setup.py) ... [?25l[?25hdone [2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.8/5.8 MB[0m [31m93.2 MB/s[0m eta [36m0:00:00[0m [2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.3/9.3 MB[0m [31m108.7 MB/s[0m eta [36m0:00:00[0m [2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m182.4/182.4 KB[0m [31m22.5 MB/s[0m eta [36m0:00:00[0m [2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m184.0/184.0 KB[0m [31m829.4 kB/s[0m eta [36m0:00:00[0m [?25h Preparing metadata (setup.py) ... [?25l[?25hdone [2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m177.4/177.4 KB[0m [31m20.6 MB/s[0m eta [36m0:00:00[0m [2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m132.0/132.0 KB[0m [31m16.9 MB/s[0m eta [36m0:00:00[0m [2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m213.0/213.0 KB[0m [31m25.0 MB/s[0m eta [36m0:00:00[0m [2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.7/4.7 MB[0m [31m92.2 MB/s[0m eta [36m0:00:00[0m [2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m164.8/164.8 KB[0m [31m22.0 MB/s[0m eta [36m0:00:00[0m [2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.0/79.0 KB[0m [31m11.3 MB/s[0m eta [36m0:00:00[0m [2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m238.9/238.9 KB[0m [31m26.9 MB/s[0m eta [36m0:00:00[0m [?25h Preparing metadata (setup.py) ... 
[?25l[?25hdone [2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m62.7/62.7 KB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m [2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m140.6/140.6 KB[0m [31m18.4 MB/s[0m eta [36m0:00:00[0m [2K [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.5/84.5 KB[0m [31m12.7 MB/s[0m eta [36m0:00:00[0m [?25h Building wheel for seqeval (setup.py) ... [?25l[?25hdone Building wheel for validators (setup.py) ... [?25l[?25hdone Building wheel for pathtools (setup.py) ... [?25l[?25hdone
import os
import glob
import json
from simpletransformers.classification import (
MultiLabelClassificationModel, MultiLabelClassificationArgs
)
import pandas as pd
from sklearn.model_selection import train_test_split
# Fetch the CookDial dataset; the dialog JSON files are read from CookDial/data/dialog/*.json (notebook-only shell escape)
!git clone https://github.com/YiweiJiang2015/CookDial.git
Cloning into 'CookDial'... remote: Enumerating objects: 596, done.[K remote: Counting objects: 100% (596/596), done.[K remote: Compressing objects: 100% (89/89), done.[K remote: Total 596 (delta 507), reused 585 (delta 501), pack-reused 0[K Receiving objects: 100% (596/596), 1.23 MiB | 23.30 MiB/s, done. Resolving deltas: 100% (507/507), done.
# 1st pass - get all possible intents, create translation dict
all_intents = set()
for file in list(glob.glob('CookDial/data/dialog/*.json')):
with open(file, encoding='utf-8') as dial_file:
dial_data = json.load(dial_file)
for message in dial_data['messages']:
if message['bot'] == False:
intents = json.loads(message['annotations'])['intent']
for intent in [intent.strip() for intent in intents.split(';')]:
if intent != '':
all_intents.add(intent)
intent2int = dict(zip(sorted(list(all_intents)), range(len(all_intents))))
intent2int
{'affirm': 0, 'confirm': 1, 'goodbye': 2, 'greeting': 3, 'negate': 4, 'other': 5, 'req_amount': 6, 'req_duration': 7, 'req_ingredient': 8, 'req_ingredient_list': 9, 'req_ingredient_list_ends': 10, 'req_ingredient_list_length': 11, 'req_instruction': 12, 'req_is_recipe_finished': 13, 'req_is_recipe_ongoing': 14, 'req_parallel_action': 15, 'req_repeat': 16, 'req_start': 17, 'req_substitute': 18, 'req_temperature': 19, 'req_title': 20, 'req_tool': 21, 'req_use_all': 22, 'thank': 23}
# 2nd pass - append utterance + intent multi-hot vectors to processed data.
# Files are visited in sorted order: glob.glob returns them in arbitrary,
# filesystem-dependent order, which would make the row order (and therefore
# the seeded train/test split) differ between machines.
processed_data = []
for path in sorted(glob.glob('CookDial/data/dialog/*.json')):
    with open(path, encoding='utf-8') as dial_file:
        dial_data = json.load(dial_file)
    for message in dial_data['messages']:
        if message['bot']:
            continue  # skip bot turns; only user utterances are collected
        annotations = json.loads(message['annotations'])
        # Filter out empty items (e.g. from a trailing ';') instead of calling
        # list.remove(''), which raises ValueError when there is no empty item
        # and leaves extras behind (KeyError in intent2int) when there are several.
        intents = [intent.strip() for intent in annotations['intent'].split(';') if intent.strip()]
        # Column i is 1 iff the utterance carries the intent mapped to i by intent2int.
        intents_multi_hot = [0] * len(all_intents)
        for intent in intents:
            intents_multi_hot[intent2int[intent]] = 1
        processed_data.append([message['utterance'], intents_multi_hot])
processed_data[:5]
[['Hi what are we making today?', [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]], ['Yes. What are the first two ingredients?', [1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], ['Ok. I have them. What are the next ingredients?', [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], ['Ok. I have them as well. What are the next?', [0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], ['Ok. what kind of milk should I use?', [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]]]
# Name the columns the way simpletransformers looks for them ('text', 'labels');
# with unnamed columns it emits a "Dataframe headers not specified" warning and
# falls back to using column 0 as text and column 1 as labels.
processed_data_df = pd.DataFrame(processed_data, columns=['text', 'labels'])
# Hold out 10% of the utterances for evaluation; the fixed seed makes the split repeatable.
train_data, test_data = train_test_split(processed_data_df, test_size=0.1, random_state=42, shuffle=True)
# Model definition: RoBERTa-base fine-tuned as a multi-label intent classifier,
# one output per known intent. The optimal number of epochs appears to lie
# somewhere between 5 and 10; 5 is used here.
NUM_TRAIN_EPOCHS = 5
model_args = MultiLabelClassificationArgs(
    num_train_epochs=NUM_TRAIN_EPOCHS,
    # allow re-running this cell over an existing output directory
    overwrite_output_dir=True,
)
model = MultiLabelClassificationModel('roberta', 'roberta-base', num_labels=len(all_intents), args=model_args)

# Fine-tune on the training split (use a GPU runtime; CPU training is very slow).
model.train_model(train_data)

# Evaluate on the held-out split. The reported metric is LRAP (label ranking
# average precision): "The obtained score is always strictly greater than 0 and
# the best value is 1". From:
# https://scikit-learn.org/stable/modules/generated/sklearn.metrics.label_ranking_average_precision_score.html
result, model_outputs, wrong_predictions = model.eval_model(test_data)
print(result)
Downloading: 0%| | 0.00/481 [00:00<?, ?B/s]
Downloading: 0%| | 0.00/501M [00:00<?, ?B/s]
Some weights of the model checkpoint at roberta-base were not used when initializing RobertaForMultiLabelSequenceClassification: ['lm_head.dense.weight', 'lm_head.bias', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'lm_head.layer_norm.weight', 'lm_head.dense.bias'] - This IS expected if you are initializing RobertaForMultiLabelSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). - This IS NOT expected if you are initializing RobertaForMultiLabelSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). Some weights of RobertaForMultiLabelSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.out_proj.bias', 'classifier.out_proj.weight', 'classifier.dense.bias', 'classifier.dense.weight'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Downloading: 0%| | 0.00/899k [00:00<?, ?B/s]
Downloading: 0%| | 0.00/456k [00:00<?, ?B/s]
/usr/local/lib/python3.8/dist-packages/simpletransformers/classification/classification_model.py:612: UserWarning: Dataframe headers not specified. Falling back to using column 0 as text and column 1 as labels. warnings.warn(
0%| | 0/4149 [00:00<?, ?it/s]
Epoch: 0%| | 0/5 [00:00<?, ?it/s]
Running Epoch 0 of 5: 0%| | 0/519 [00:00<?, ?it/s]
Running Epoch 1 of 5: 0%| | 0/519 [00:00<?, ?it/s]
Running Epoch 2 of 5: 0%| | 0/519 [00:00<?, ?it/s]
Running Epoch 3 of 5: 0%| | 0/519 [00:00<?, ?it/s]
Running Epoch 4 of 5: 0%| | 0/519 [00:00<?, ?it/s]
/usr/local/lib/python3.8/dist-packages/simpletransformers/classification/classification_model.py:1454: UserWarning: Dataframe headers not specified. Falling back to using column 0 as text and column 1 as labels. warnings.warn(
0%| | 0/461 [00:00<?, ?it/s]
Running Evaluation: 0%| | 0/58 [00:00<?, ?it/s]
{'LRAP': 0.932345722168039, 'eval_loss': 0.044855168964392664}
# Only this class and the trained model directory ("checkpoint" folders inside not needed) have to be copied into the chatbot project
class NLU:
    """Intent recognizer wrapping the fine-tuned multi-label RoBERTa model.

    ``predict`` maps one user utterance to the list of intent labels whose
    output the model switches on (value 1 in its prediction vector).
    """

    # Intent labels in model-output order: position i of the prediction vector
    # corresponds to INTENT_LABELS[i]. This must match the alphabetical order
    # that was used to build the training label vectors.
    INTENT_LABELS = (
        'affirm', 'confirm', 'goodbye', 'greeting', 'negate', 'other',
        'req_amount', 'req_duration', 'req_ingredient', 'req_ingredient_list',
        'req_ingredient_list_ends', 'req_ingredient_list_length',
        'req_instruction', 'req_is_recipe_finished', 'req_is_recipe_ongoing',
        'req_parallel_action', 'req_repeat', 'req_start', 'req_substitute',
        'req_temperature', 'req_title', 'req_tool', 'req_use_all', 'thank',
    )

    def __init__(self, model_dir='outputs', use_cuda=True):
        """Load the trained model.

        Args:
            model_dir: directory containing the trained model (default 'outputs').
            use_cuda: forwarded to MultiLabelClassificationModel; pass False on a
                CPU-only machine (simpletransformers' default of True requires
                a CUDA device).
        """
        self.int2intent = dict(enumerate(self.INTENT_LABELS))
        self.model = MultiLabelClassificationModel(
            'roberta', model_dir, num_labels=len(self.int2intent), use_cuda=use_cuda,
        )

    def _vector_to_intents(self, predictions_vector):
        """Return the labels of all positions equal to 1 in *predictions_vector*, in index order."""
        return [self.int2intent[index] for index, flag in enumerate(predictions_vector) if flag == 1]

    def predict(self, utterance):
        """Return the list of intent labels predicted for a single *utterance* string."""
        # The model predicts on a batch; wrap the utterance and take the only row.
        predictions, _raw_outputs = self.model.predict([utterance])
        return self._vector_to_intents(predictions[0])
# Load the trained model from the 'outputs' directory and run sample utterances through it;
# each predict call returns a list of intent labels (recorded result shown after the call)
nlu = NLU()
nlu.predict('Hi, what do we cook today? Recommend me something.')
0%| | 0/1 [00:00<?, ?it/s]
0%| | 0/1 [00:00<?, ?it/s]
['greeting', 'req_title']
# Sample: a short request to continue
nlu.predict('Continue, please.')
0%| | 0/1 [00:00<?, ?it/s]
0%| | 0/1 [00:00<?, ?it/s]
['req_instruction']
# Sample: a status report followed by a question, i.e. an utterance that may carry more than one intent
nlu.predict('The chicken is golden brown now. What do I need to do?')
0%| | 0/1 [00:00<?, ?it/s]
0%| | 0/1 [00:00<?, ?it/s]
['confirm', 'req_instruction']
# Sample: a question about the next ingredients
nlu.predict('What are the next ingredients?')
0%| | 0/1 [00:00<?, ?it/s]
0%| | 0/1 [00:00<?, ?it/s]
['req_ingredient']
# Sample: a request to repeat the current step
nlu.predict("I didn't understand. Can you repeat this step?")
0%| | 0/1 [00:00<?, ?it/s]
0%| | 0/1 [00:00<?, ?it/s]
['req_repeat']