This commit is contained in:
Filip Gralinski 2021-09-27 07:36:37 +02:00
parent a45fd570e5
commit 72c6fbcbf6
2 changed files with 94 additions and 44 deletions

View File

@ -3,6 +3,22 @@
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {},
"source": [
"![Logo 1](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech1.jpg)\n",
"<div class=\"alert alert-block alert-info\">\n",
"<h1>Ekstrakcja informacji</h1>\n",
"<h2>1. <i>Wyszukiwarki - wprowadzenie</i> [wykład]</h2> \n",
"<h3>Filip Graliński (2021)</h3>\n",
"</div>\n",
"\n",
"![Logo 2](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech2.jpg)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [ "source": [
"# Wyszukiwarki - wprowadzenie\n", "# Wyszukiwarki - wprowadzenie\n",
"\n", "\n",
@ -13,7 +29,10 @@
}, },
{ {
"cell_type": "markdown", "cell_type": "markdown",
"metadata": {}, "metadata": {
"jp-MarkdownHeadingCollapsed": true,
"tags": []
},
"source": [ "source": [
"## Wyszukiwarki\n", "## Wyszukiwarki\n",
"\n", "\n",
@ -1676,7 +1695,7 @@
], ],
"metadata": { "metadata": {
"kernelspec": { "kernelspec": {
"display_name": "Python 3", "display_name": "Python 3 (ipykernel)",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },
@ -1690,7 +1709,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.9.1" "version": "3.9.6"
} }
}, },
"nbformat": 4, "nbformat": 4,

View File

@ -24,7 +24,7 @@
"Jest kilka sposobów na pretrenowanie modelu, w każdym razie sprowadza\n", "Jest kilka sposobów na pretrenowanie modelu, w każdym razie sprowadza\n",
"się do odgadywania następnego bądź zamaskowanego słowa.\n", "się do odgadywania następnego bądź zamaskowanego słowa.\n",
"W każdym razie zawsze stosujemy softmax (być może ze „sztuczkami” takimi jak\n", "W każdym razie zawsze stosujemy softmax (być może ze „sztuczkami” takimi jak\n",
"negatywne próbkowanie albo hierarchiczny softmax) na pewnej **representecji kontekstowej**:\n", "negatywne próbkowanie albo hierarchiczny softmax) na pewnej **reprezentacji kontekstowej**:\n",
"\n", "\n",
"$$\\vec{p} = \\operatorname{softmax}(f(\\vec{c})).$$\n", "$$\\vec{p} = \\operatorname{softmax}(f(\\vec{c})).$$\n",
"\n", "\n",
@ -66,7 +66,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 5, "execution_count": 17,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{ {
@ -79,39 +79,39 @@
{ {
"data": { "data": {
"text/plain": [ "text/plain": [
"[('Ġon', 0.6786560416221619),\n", "[('Âł', 0.6182783842086792),\n",
" ('Ġupon', 0.04339785501360893),\n", " ('È', 0.1154019758105278),\n",
" ('Ġheavily', 0.02208443358540535),\n", " ('Ñģ', 0.026960616931319237),\n",
" ('Ġin', 0.021049050614237785),\n", " ('_____', 0.024418892338871956),\n",
" (',', 0.020188499242067337),\n", " ('________', 0.014962316490709782),\n",
" ('Ġa', 0.01833895780146122),\n", " ('ÃĤ', 0.010653386823832989),\n",
" ('Ġvery', 0.017935041338205338),\n", " ('ä¸Ń', 0.008340531960129738),\n",
" ('Ġentirely', 0.017528969794511795),\n", " ('Ñ', 0.007557711564004421),\n",
" ('Ġlargely', 0.016769640147686005),\n", " ('Ê', 0.007046067621558905),\n",
" ('Ġto', 0.01009418722242117),\n", " ('ãĢ', 0.006875576451420784),\n",
" ('Ġgreatly', 0.010009866207838058),\n", " ('ile', 0.006685272324830294),\n",
" ('Ġnot', 0.009016563184559345),\n", " ('____', 0.006307446397840977),\n",
" ('Ġmore', 0.005853226874023676),\n", " ('âĢĭ', 0.006306538358330727),\n",
" ('Ġprimarily', 0.005203146021813154),\n", " ('ÑĢ', 0.006197483278810978),\n",
" ('Ġstrongly', 0.0034501152113080025),\n", " ('ĠBelarus', 0.006108700763434172),\n",
" ('Ġpartly', 0.0033184229396283627),\n", " ('Æ', 0.005720408633351326),\n",
" ('Ġmuch', 0.0033095215912908316),\n", " ('ĠPoland', 0.0053678699769079685),\n",
" ('Ġmostly', 0.0032150144688785076),\n", " ('á¹', 0.004606408067047596),\n",
" ('Ġmainly', 0.0030899408739060163),\n", " ('îĢ', 0.004161055199801922),\n",
" ('Ġfor', 0.003034428460523486),\n", " ('????', 0.004056799225509167),\n",
" ('.', 0.0028878094162791967),\n", " ('_______', 0.0038176667876541615),\n",
" ('Ġboth', 0.0028405177872627974),\n", " ('ä¸', 0.0036082742735743523),\n",
" ('Ġsomewhat', 0.0028194624464958906),\n", " ('Ì', 0.003221835708245635),\n",
" ('Ġcru', 0.002263976726680994),\n", " ('urs', 0.003080119378864765),\n",
" ('Ġas', 0.00221616611815989),\n", " ('________________', 0.0027312245219945908),\n",
" ('Ġof', 0.0022000609897077084),\n", " ('ĠLithuania', 0.0023860156070441008),\n",
" ('Ġalmost', 0.001968063646927476),\n", " ('ich', 0.0021211160346865654),\n",
" ('Ġat', 0.0018015997484326363),\n", " ('iz', 0.002069818088784814),\n",
" ('Ġhighly', 0.0017461496172472835),\n", " ('vern', 0.002001357264816761),\n",
" ('Ġcompletely', 0.001692073536105454)]" " ('ÅĤ', 0.001717406208626926)]"
] ]
}, },
"execution_count": 5, "execution_count": 17,
"metadata": {}, "metadata": {},
"output_type": "execute_result" "output_type": "execute_result"
} }
@ -121,12 +121,11 @@
"from transformers import GPT2Tokenizer, GPT2LMHeadModel\n", "from transformers import GPT2Tokenizer, GPT2LMHeadModel\n",
"tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')\n", "tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')\n",
"model = GPT2LMHeadModel.from_pretrained('gpt2-large')\n", "model = GPT2LMHeadModel.from_pretrained('gpt2-large')\n",
"text = \"This issue depends\"\n", "text = 'Warsaw is the capital city of'\n",
"encoded_input = tokenizer(text, return_tensors='pt')\n", "encoded_input = tokenizer(text, return_tensors='pt')\n",
"output = model(**encoded_input)\n", "output = model(**encoded_input)\n",
"next_token_probs = torch.softmax(output[0][:, -1, :][0], dim=0)\n", "next_token_probs = torch.softmax(output[0][:, -1, :][0], dim=0)\n",
"\n", "\n",
"next_token_probs\n",
"nb_of_tokens = next_token_probs.size()[0]\n", "nb_of_tokens = next_token_probs.size()[0]\n",
"print(nb_of_tokens)\n", "print(nb_of_tokens)\n",
"\n", "\n",
@ -198,11 +197,28 @@
"execution_count": 1, "execution_count": 1,
"metadata": {}, "metadata": {},
"outputs": [ "outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/filipg/.local/lib/python3.9/site-packages/transformers/models/auto/modeling_auto.py:806: FutureWarning: The class `AutoModelWithLMHead` is deprecated and will be removed in a future version. Please use `AutoModelForCausalLM` for causal language models, `AutoModelForMaskedLM` for masked language models and `AutoModelForSeq2SeqLM` for encoder-decoder models.\n",
" warnings.warn(\n"
]
},
{ {
"name": "stdout", "name": "stdout",
"output_type": "stream", "output_type": "stream",
"text": [ "text": [
"# Out[3]:" "W którym państwie leży Bombaj? W USA. (score: 0.16715531051158905)\n",
"W którym państwie leży Bombaj? W India. (score: 0.09912960231304169)\n",
"W którym państwie leży Bombaj? W Indian. (score: 0.039642028510570526)\n",
"W którym państwie leży Bombaj? W Nepal. (score: 0.027137665078043938)\n",
"W którym państwie leży Bombaj? W Pakistan. (score: 0.027065709233283997)\n",
"W którym państwie leży Bombaj? W Polsce. (score: 0.023737527430057526)\n",
"W którym państwie leży Bombaj? W .... (score: 0.02306722290813923)\n",
"W którym państwie leży Bombaj? W Bangladesh. (score: 0.022106658667325974)\n",
"W którym państwie leży Bombaj? W .... (score: 0.01628892682492733)\n",
"W którym państwie leży Bombaj? W Niemczech. (score: 0.014501162804663181)\n"
] ]
} }
], ],
@ -213,7 +229,7 @@
"tokenizer = AutoTokenizer.from_pretrained(\"xlm-roberta-large\")\n", "tokenizer = AutoTokenizer.from_pretrained(\"xlm-roberta-large\")\n",
"model = AutoModelWithLMHead.from_pretrained(\"xlm-roberta-large\")\n", "model = AutoModelWithLMHead.from_pretrained(\"xlm-roberta-large\")\n",
"\n", "\n",
"sequence = f'II wojna światowa zakończyła się w {tokenizer.mask_token} roku.'\n", "sequence = f'W którym państwie leży Bombaj? W {tokenizer.mask_token}.'\n",
"\n", "\n",
"input_ids = tokenizer.encode(sequence, return_tensors=\"pt\")\n", "input_ids = tokenizer.encode(sequence, return_tensors=\"pt\")\n",
"mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1]\n", "mask_token_index = torch.where(input_ids == tokenizer.mask_token_id)[1]\n",
@ -262,9 +278,24 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": 2,
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [
{
"data": {
"text/plain": [
"['World War II ended in World War II.',\n",
" 'World War II ended in 1945..',\n",
" 'World War II ended in 1945.',\n",
" 'World War II ended in 1945.',\n",
" 'World War II ended in 1945.']"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [ "source": [
"from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration\n", "from transformers import T5Tokenizer, T5Config, T5ForConditionalGeneration\n",
"\n", "\n",
@ -276,7 +307,7 @@
"\n", "\n",
"slot = '<extra_id_0>'\n", "slot = '<extra_id_0>'\n",
"\n", "\n",
"text = f'Warsaw is the {slot} of Poland.'\n", "text = f'World War II ended in {slot}.'\n",
"\n", "\n",
"encoded = t5_tokenizer.encode_plus(text, add_special_tokens=True, return_tensors='pt')\n", "encoded = t5_tokenizer.encode_plus(text, add_special_tokens=True, return_tensors='pt')\n",
"input_ids = encoded['input_ids']\n", "input_ids = encoded['input_ids']\n",
@ -334,5 +365,5 @@
"org": null "org": null
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 1 "nbformat_minor": 4
} }