{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Lab 2 exercises\n",
"\n",
"## Task 1\n",
"Find 2 examples (words, sentences) where the differences between the BERT and RoBERTa tokenizers are **noticeable**"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from transformers import BertTokenizer, RobertaTokenizer, PreTrainedTokenizerFast, AutoTokenizer\n",
"\n",
"bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')\n",
"roberta_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"marion\n",
"Mar ion\n"
]
}
],
"source": [
"text_en = 'Marion' # a first name\n",
"\n",
"print(' '.join(bert_tokenizer.tokenize(text_en)))\n",
"print(' '.join(roberta_tokenizer.tokenize(text_en)))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"baptist\n",
"b apt ist\n"
]
}
],
"source": [
"text_en = 'baptist' # a Baptist\n",
"\n",
"print(' '.join(bert_tokenizer.tokenize(text_en)))\n",
"print(' '.join(roberta_tokenizer.tokenize(text_en)))"
]
},
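{
"cell_type": "markdown",
"metadata": {},
"source": [
"The differences above come from the tokenization schemes: `bert-base-uncased` lowercases its input and uses a WordPiece vocabulary (continuation pieces are prefixed with `##`), while `roberta-base` is case-sensitive and uses byte-level BPE (the `Ġ` marker encodes a leading space). As an added sketch (not part of the original exercise), the cell below also compares the full encodings, which additionally differ in the special tokens `[CLS]`/`[SEP]` vs `<s>`/`</s>`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: compare complete encodings, including the special tokens\n",
"text_en = 'Marion'\n",
"\n",
"bert_ids = bert_tokenizer(text_en)['input_ids']\n",
"roberta_ids = roberta_tokenizer(text_en)['input_ids']\n",
"\n",
"print(bert_tokenizer.convert_ids_to_tokens(bert_ids))\n",
"print(roberta_tokenizer.convert_ids_to_tokens(roberta_ids))"
]
},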
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Task 2\n",
"Find 2 examples (words, sentences) where the BERT and RoBERTa tokenizers give similar results"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"football\n",
"Football\n"
]
}
],
"source": [
"text_en = 'Football'\n",
"\n",
"print(' '.join(bert_tokenizer.tokenize(text_en)))\n",
"print(' '.join(roberta_tokenizer.tokenize(text_en)))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"i like reading .\n",
"I Ġlike Ġreading .\n"
]
}
],
"source": [
"text_en = 'I like reading.'\n",
"\n",
"print(' '.join(bert_tokenizer.tokenize(text_en)))\n",
"print(' '.join(roberta_tokenizer.tokenize(text_en)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Task 3\n",
"Check how the BERT/RoBERTa tokenizer behaves on a language other than English"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"bard ##zo lu ##bie inform ##at ##yk ##e .\n",
"B ard zo Ġl ubi Ä Ļ Ġinform at yk Ä Ļ .\n"
]
}
],
"source": [
"text_pl = 'Bardzo lubię informatykę.'\n",
"\n",
"# Tokenization with the English-language models\n",
"print(' '.join(bert_tokenizer.tokenize(text_pl)))\n",
"print(' '.join(roberta_tokenizer.tokenize(text_pl)))"
]
},
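{
"cell_type": "markdown",
"metadata": {},
"source": [
"Both English tokenizers break the Polish sentence into many short subwords, and RoBERTa falls back to raw bytes for `ę`. The sketch below (an addition, not part of the original exercise) simply counts the resulting tokens to quantify that fragmentation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: token counts for the same Polish sentence under both English tokenizers\n",
"for name, tok in [('BERT', bert_tokenizer), ('RoBERTa', roberta_tokenizer)]:\n",
"    tokens = tok.tokenize(text_pl)\n",
"    print(f'{name}: {len(tokens)} tokens for {len(text_pl.split())} words')"
]
},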
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Task 4\n",
"Check how the BERT/RoBERTa tokenizer behaves on medical text or other specialized text."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"when the exclude ##r , end ##ura ##nt , and zenith were poole ##d the rate of abdominal ao ##rti ##c an ##eur ##ys ##m ru ##pt ##ure was observed to be significantly higher among patients with the early af ##x .\n",
"When Ġthe ĠEx clud er , ĠEnd ur ant , Ġand ĠZen ith Ġwere Ġpooled Ġthe Ġrate Ġof Ġabdominal Ġa ort ic Ġan eur ys m Ġrupture Ġwas Ġobserved Ġto Ġbe Ġsignificantly Ġhigher Ġamong Ġpatients Ġwith Ġthe Ġearly ĠAF X .\n"
]
}
],
"source": [
"# Text from a medical article\n",
"medical_en = 'When the Excluder, Endurant, and Zenith were pooled the rate of abdominal aortic aneurysm rupture was observed to be significantly higher among patients with the early AFX.'\n",
"\n",
"print(' '.join(bert_tokenizer.tokenize(medical_en)))\n",
"print(' '.join(roberta_tokenizer.tokenize(medical_en)))"
]
},
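{
"cell_type": "markdown",
"metadata": {},
"source": [
"Device names and anatomical vocabulary are the parts that fall apart into many subwords. The sketch below (an addition, not part of the original exercise) lists the words from the sentence that the BERT tokenizer splits into the most WordPiece tokens."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: words from the medical sentence sorted by how many WordPiece tokens they need\n",
"words = medical_en.replace(',', '').replace('.', '').split()\n",
"pieces = {w: bert_tokenizer.tokenize(w) for w in words}\n",
"\n",
"for word, toks in sorted(pieces.items(), key=lambda kv: -len(kv[1]))[:5]:\n",
"    print(f'{word:15} -> {toks}')"
]
},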
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Task 5\n",
"Run 3 *FillMask* examples each for the following models:\n",
"- BERT/RoBERTa\n",
"- Polish RoBERTa"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### BERT - English"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.bias', 'cls.seq_relationship.weight']\n",
"- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
}
],
"source": [
"import torch\n",
"from torch.nn import functional as F\n",
"from transformers import BertForMaskedLM\n",
"\n",
"bert_model = BertForMaskedLM.from_pretrained('bert-base-uncased')"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 0\tstars 0.6143525838851929\n",
" 1\tclouds 0.2152138501405716\n",
" 2\tbirds 0.008692129515111446\n",
" 3\tblue 0.008089331910014153\n",
" 4\tcloud 0.005828939378261566\n",
" 5\tsunshine 0.005086773540824652\n",
" 6\tlight 0.005068401340395212\n",
" 7\tflowers 0.004763070959597826\n",
" 8\tdarkness 0.004391019232571125\n",
" 9\tlights 0.004141420125961304\n"
]
}
],
"source": [
"inputs_mlm = bert_tokenizer(f'The sky was full of {bert_tokenizer.mask_token}.', return_tensors='pt')\n",
"labels_mlm = bert_tokenizer(\"The sky was full of stars.\", return_tensors=\"pt\")[\"input_ids\"]\n",
"\n",
"outputs_mlm = bert_model(**inputs_mlm, labels=labels_mlm)\n",
"\n",
"mask_token_idx = 6 # CLS + 5 tokens\n",
"softmax_mlm = F.softmax(outputs_mlm.logits, dim = -1)\n",
"\n",
"mask_token = softmax_mlm[0, mask_token_idx, :]\n",
"top_10 = torch.topk(mask_token, 10, dim = 0)\n",
"\n",
"for i, (token_id, prob) in enumerate(zip(top_10.indices, top_10.values)):\n",
" token = bert_tokenizer.decode([token_id])\n",
" print(f'{i:2}\\t{token:25}', prob.item())"
]
},
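{
"cell_type": "markdown",
"metadata": {},
"source": [
"The cells in this section hard-code `mask_token_idx` by counting tokens by hand. As an added sketch (not part of the original exercise), the mask position can also be read off the encoded ids via `mask_token_id`, which avoids recounting for every sentence."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: locate the [MASK] position programmatically instead of hard-coding it\n",
"mask_positions = (inputs_mlm['input_ids'][0] == bert_tokenizer.mask_token_id).nonzero(as_tuple=True)[0]\n",
"print(mask_positions.item())  # should match the hand-counted index used above"
]
},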
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" 0\ttight 0.2341388612985611\n",
" 1\tbig 0.11350443959236145\n",
" 2\theavy 0.07258473336696625\n",
" 3\tshort 0.05406404659152031\n",
" 4\tlong 0.050229042768478394\n",
" 5\tlight 0.03884173184633255\n",
" 6\tthin 0.025743598118424416\n",
" 7\trevealing 0.020789707079529762\n",
" 8\twarm 0.01982339844107628\n",
" 9\tsmall 0.019418802112340927\n"
]
}
],
"source": [
"inputs_mlm = bert_tokenizer(f'This jacket is a little too {bert_tokenizer.mask_token}.', return_tensors='pt')\n",
"labels_mlm = bert_tokenizer(\"This jacket is a little too big.\", return_tensors=\"pt\")[\"input_ids\"]\n",
"\n",
"outputs_mlm = bert_model(**inputs_mlm, labels=labels_mlm)\n",
"\n",
"mask_token_idx = 7 # CLS + 6 tokens\n",
"softmax_mlm = F.softmax(outputs_mlm.logits, dim = -1)\n",
"\n",
"mask_token = softmax_mlm[0, mask_token_idx, :]\n",
"top_10 = torch.topk(mask_token, 10, dim = 0)\n",
"\n",
"for i, (token_id, prob) in enumerate(zip(top_10.indices, top_10.values)):\n",
" token = bert_tokenizer.decode([token_id])\n",
" print(f'{i:2}\\t{token:25}', prob.item())"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['what', \"'\", 's', 'your', 'favorite', 'ice', 'cream', '[MASK]', '?']\n",
" 0\tflavor 0.5929659008979797\n",
" 1\tnow 0.014950926415622234\n",
" 2\tline 0.014521223492920399\n",
" 3\trecipe 0.013670633547008038\n",
" 4\tcolor 0.010578353889286518\n",
" 5\t? 0.00849001295864582\n",
" 6\tthing 0.00799252837896347\n",
" 7\tplease 0.007873623631894588\n",
" 8\ttoday 0.007739454973489046\n",
" 9\tnumber 0.007451422978192568\n"
]
}
],
"source": [
"inputs_mlm = bert_tokenizer(f\"What's your favorite ice cream {bert_tokenizer.mask_token}?\", return_tensors='pt')\n",
"labels_mlm = bert_tokenizer(\"What's your favorite ice cream flavor?\", return_tensors=\"pt\")[\"input_ids\"]\n",
"print(bert_tokenizer.tokenize(f\"What's your favorite ice cream {bert_tokenizer.mask_token}?\"))\n",
"\n",
"outputs_mlm = bert_model(**inputs_mlm, labels=labels_mlm)\n",
"\n",
"mask_token_idx = 8 # CLS + 7 tokens\n",
"softmax_mlm = F.softmax(outputs_mlm.logits, dim = -1)\n",
"\n",
"mask_token = softmax_mlm[0, mask_token_idx, :]\n",
"top_10 = torch.topk(mask_token, 10, dim = 0)\n",
"\n",
"for i, (token_id, prob) in enumerate(zip(top_10.indices, top_10.values)):\n",
" token = bert_tokenizer.decode([token_id])\n",
" print(f'{i:2}\\t{token:25}', prob.item())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### RoBERTa - English"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from transformers import RobertaForMaskedLM\n",
"\n",
"roberta_model = RobertaForMaskedLM.from_pretrained('roberta-base')"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['Hand', 'Ġme', 'Ġthe', '<mask>', '!']\n",
" 0\t keys 0.33524537086486816\n",
" 1\t phone 0.05494626611471176\n",
" 2\t key 0.02826027013361454\n",
" 3\t paper 0.025939658284187317\n",
" 4\t papers 0.01922498270869255\n",
" 5\t reins 0.018558315932750702\n",
" 6\t cup 0.016417579725384712\n",
" 7\t bag 0.015210084617137909\n",
" 8\t coffee 0.014366202056407928\n",
" 9\t gun 0.013706102967262268\n"
]
}
],
"source": [
"inputs_mlm = roberta_tokenizer(f'Hand me the {roberta_tokenizer.mask_token}!', return_tensors='pt')\n",
"labels_mlm = roberta_tokenizer(\"Hand me the hammer!\", return_tensors=\"pt\")[\"input_ids\"]\n",
"print(roberta_tokenizer.tokenize(f'Hand me the {roberta_tokenizer.mask_token}!'))\n",
"\n",
"outputs_mlm = roberta_model(**inputs_mlm, labels=labels_mlm)\n",
"\n",
"mask_token_idx = 4 # <s> + 3 tokens\n",
"softmax_mlm = F.softmax(outputs_mlm.logits, dim = -1)\n",
"\n",
"mask_token = softmax_mlm[0, mask_token_idx, :]\n",
"top_10 = torch.topk(mask_token, 10, dim = 0)\n",
"\n",
"for i, (token_id, prob) in enumerate(zip(top_10.indices, top_10.values)):\n",
" token = roberta_tokenizer.decode([token_id])\n",
" print(f'{i:2}\\t{token:25}', prob.item())"
]
},
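{
"cell_type": "markdown",
"metadata": {},
"source": [
"For comparison, the same fill-mask query can be run through the high-level `pipeline` API, which locates the mask and applies the softmax internally. This is an added sketch for reference (not part of the original exercise) and it reuses the model and tokenizer that are already loaded."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from transformers import pipeline\n",
"\n",
"# Sketch: fill-mask via the pipeline API, reusing the loaded RoBERTa model and tokenizer\n",
"fill_mask = pipeline('fill-mask', model=roberta_model, tokenizer=roberta_tokenizer)\n",
"for prediction in fill_mask(f'Hand me the {roberta_tokenizer.mask_token}!')[:3]:\n",
"    print(prediction['token_str'], prediction['score'])"
]
},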
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### RoBERTa - Polish"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. \n",
"The tokenizer class you load from this checkpoint is 'RobertaTokenizer'. \n",
"The class this function is called from is 'PreTrainedTokenizerFast'.\n"
]
}
],
"source": [
"from transformers import AutoModelForMaskedLM\n",
"\n",
"polish_roberta_tokenizer = PreTrainedTokenizerFast.from_pretrained('sdadas/polish-roberta-large-v1')\n",
"polish_roberta_model = AutoModelForMaskedLM.from_pretrained('sdadas/polish-roberta-large-v1')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['Bar', 'dzo', '▁lubię', ' <mask>', '.']\n",
" 0\tczytać 0.06616953760385513\n",
" 1\tpodróżować 0.04533696547150612\n",
" 2\tgotować 0.04076462611556053\n",
" 3\tmuzykę 0.039369307458400726\n",
" 4\tkoty 0.03558063879609108\n",
" 5\tpisać 0.03538721054792404\n",
" 6\tksiążki 0.033440858125686646\n",
" 7\tśpiewać 0.02773296646773815\n",
" 8\tsport 0.027220433577895164\n",
" 9\ttańczyć 0.026598699390888214\n"
]
}
],
"source": [
"inputs_mlm = polish_roberta_tokenizer(f'Bardzo lubię {polish_roberta_tokenizer.mask_token}.', return_tensors='pt')\n",
"labels_mlm = polish_roberta_tokenizer(\"Bardzo lubię czytać.\", return_tensors=\"pt\")[\"input_ids\"]\n",
"print(polish_roberta_tokenizer.tokenize(f'Bardzo lubię {polish_roberta_tokenizer.mask_token}.'))\n",
"\n",
"outputs_mlm = polish_roberta_model(**inputs_mlm, labels=labels_mlm)\n",
"\n",
"mask_token_idx = 4 # <s> + 3 tokens\n",
"softmax_mlm = F.softmax(outputs_mlm.logits, dim = -1)\n",
"\n",
"mask_token = softmax_mlm[0, mask_token_idx, :]\n",
"top_10 = torch.topk(mask_token, 10, dim = 0)\n",
"\n",
"for i, (token_id, prob) in enumerate(zip(top_10.indices, top_10.values)):\n",
" token = polish_roberta_tokenizer.decode([token_id])\n",
" print(f'{i:2}\\t{token:25}', prob.item())"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['Za', 'jęcia', '▁na', '▁uczelni', '▁są', ' <mask>', '.']\n",
" 0\tbezpłatne 0.9145433902740479\n",
" 1\tobowiązkowe 0.014430041424930096\n",
" 2\tprowadzone 0.010215427726507187\n",
" 3\tzróżnicowane 0.008744887076318264\n",
" 4\tróżnorodne 0.00670977309346199\n",
" 5\tnastępujące 0.004183280747383833\n",
" 6\totwarte 0.002896391786634922\n",
" 7\tintensywne 0.002672090893611312\n",
" 8\trealizowane 0.0019869415555149317\n",
" 9\tok 0.0018993624253198504\n"
]
}
],
"source": [
"inputs_mlm = polish_roberta_tokenizer(f'Zajęcia na uczelni są {polish_roberta_tokenizer.mask_token}.', return_tensors='pt')\n",
"labels_mlm = polish_roberta_tokenizer(\"Zajęcia na uczelni są ciekawe.\", return_tensors=\"pt\")[\"input_ids\"]\n",
"print(polish_roberta_tokenizer.tokenize(f'Zajęcia na uczelni są {polish_roberta_tokenizer.mask_token}.'))\n",
"\n",
"outputs_mlm = polish_roberta_model(**inputs_mlm, labels=labels_mlm)\n",
"\n",
"mask_token_idx = 6 # <s> + 5 tokens\n",
"softmax_mlm = F.softmax(outputs_mlm.logits, dim = -1)\n",
"\n",
"mask_token = softmax_mlm[0, mask_token_idx, :]\n",
"top_10 = torch.topk(mask_token, 10, dim = 0)\n",
"\n",
"for i, (token_id, prob) in enumerate(zip(top_10.indices, top_10.values)):\n",
" token = polish_roberta_tokenizer.decode([token_id])\n",
" print(f'{i:2}\\t{token:25}', prob.item())"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"['Ju', 'tro', '▁na', '▁obiad', '▁będzie', ' <mask>', '.']\n",
" 0\tryba 0.27743467688560486\n",
" 1\tmięso 0.1686241328716278\n",
" 2\tciasto 0.024455789476633072\n",
" 3\tryż 0.0164520051330328\n",
" 4\tniedziela 0.013327408581972122\n",
" 5\tmasło 0.01118378434330225\n",
" 6\tobiad 0.010521633550524712\n",
" 7\tchleb 0.00991259329020977\n",
" 8\tczwartek 0.009901482611894608\n",
" 9\twino 0.008945722132921219\n"
]
}
],
"source": [
"inputs_mlm = polish_roberta_tokenizer(f'Jutro na obiad będzie {polish_roberta_tokenizer.mask_token}.', return_tensors='pt')\n",
"labels_mlm = polish_roberta_tokenizer(\"Jutro na obiad będzie ryba.\", return_tensors=\"pt\")[\"input_ids\"]\n",
"print(polish_roberta_tokenizer.tokenize(f'Jutro na obiad będzie {polish_roberta_tokenizer.mask_token}.'))\n",
"\n",
"outputs_mlm = polish_roberta_model(**inputs_mlm, labels=labels_mlm)\n",
"\n",
"mask_token_idx = 6 # <s> + 5 tokens\n",
"softmax_mlm = F.softmax(outputs_mlm.logits, dim = -1)\n",
"\n",
"mask_token = softmax_mlm[0, mask_token_idx, :]\n",
"top_10 = torch.topk(mask_token, 10, dim = 0)\n",
"\n",
"for i, (token_id, prob) in enumerate(zip(top_10.indices, top_10.values)):\n",
" token = polish_roberta_tokenizer.decode([token_id])\n",
" print(f'{i:2}\\t{token:25}', prob.item())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Task 6\n",
"Try to check whether one sentence follows the other."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForNextSentencePrediction: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias']\n",
"- This IS expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertForNextSentencePrediction from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"The next sentence is random: False\n"
]
}
],
"source": [
"from transformers import BertTokenizer, BertForNextSentencePrediction\n",
"import torch\n",
"\n",
"tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n",
"model = BertForNextSentencePrediction.from_pretrained(\"bert-base-uncased\")\n",
"\n",
"prompt = \"In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced.\"\n",
"next_sentence = \"In other cases pizza may be sliced.\"\n",
"encoding = tokenizer(prompt, next_sentence, return_tensors=\"pt\")\n",
"\n",
"outputs = model(**encoding, labels=torch.LongTensor([1]))\n",
"logits = outputs.logits\n",
"\n",
"sentenceWasRandom = logits[0, 0] < logits[0, 1]\n",
"print(\"The next sentence is random: \" + str(sentenceWasRandom.item()))"
]
},
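{
"cell_type": "markdown",
"metadata": {},
"source": [
"The raw comparison of the two logits can be turned into probabilities with a softmax: index 0 scores \"the second sentence follows the first\" and index 1 scores \"the second sentence is random\". The cell below is an added sketch (not part of the original exercise) that prints both probabilities for the encoding from the previous cell."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from torch.nn import functional as F\n",
"\n",
"# Sketch: probabilities for the two NSP classes (0 = is next sentence, 1 = random)\n",
"nsp_probs = F.softmax(logits, dim=-1)[0]\n",
"print(f'P(is next) = {nsp_probs[0].item():.4f}, P(random) = {nsp_probs[1].item():.4f}')"
]
}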
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.6 64-bit",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.6"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}