aitech-moj-2023/cw/12_Ensemble_modeli.ipynb
Jakub Pokrywka 85d14a1c10 update
2022-07-05 11:24:56 +02:00

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![Logo 1](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech1.jpg)\n",
"<div class=\"alert alert-block alert-info\">\n",
"<h1> Language Modeling</h1>\n",
"<h2> 12. <i>Recurrent neural model</i> [exercises]</h2> \n",
"<h3> Jakub Pokrywka (2022)</h3>\n",
"</div>\n",
"\n",
"![Logo 2](https://git.wmi.amu.edu.pl/AITech/Szablon/raw/branch/master/Logotyp_AITech2.jpg)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Model ensembles"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"How can we improve the prediction score on a machine learning task?"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"!mkdir -p dev-0"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We have a challenge:"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"https://gonito.net/challenge-all-submissions/ireland-news-headlines-word-gap"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"https://github.com/kubapok/ireland-news-word-gap/tree/0c6557c8a3cd6d8c77f64618850b2ae82c19476a"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2022-05-13 13:23:05-- https://github.com/kubapok/ireland-news-word-gap/raw/11c72875023c5c01c9d0c0ca39d72c90c840aeb3/dev-0/out.tsv\n",
"Resolving github.com (github.com)... 140.82.121.4\n",
"Connecting to github.com (github.com)|140.82.121.4|:443... connected.\n",
"HTTP request sent, awaiting response... 302 Found\n",
"Location: https://raw.githubusercontent.com/kubapok/ireland-news-word-gap/11c72875023c5c01c9d0c0ca39d72c90c840aeb3/dev-0/out.tsv [following]\n",
"--2022-05-13 13:23:06-- https://raw.githubusercontent.com/kubapok/ireland-news-word-gap/11c72875023c5c01c9d0c0ca39d72c90c840aeb3/dev-0/out.tsv\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.108.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 63249692 (60M) [text/plain]\n",
"Saving to: out.tsv\n",
"\n",
"out.tsv 100%[===================>] 60,32M 26,7MB/s in 2,3s \n",
"\n",
"2022-05-13 13:23:08 (26,7 MB/s) - out.tsv saved [63249692/63249692]\n",
"\n",
"--2022-05-13 13:23:09-- https://github.com/kubapok/ireland-news-word-gap/raw/0c6557c8a3cd6d8c77f64618850b2ae82c19476a/dev-0/out.tsv\n",
"Resolving github.com (github.com)... 140.82.121.4\n",
"Connecting to github.com (github.com)|140.82.121.4|:443... connected.\n",
"HTTP request sent, awaiting response... 302 Found\n",
"Location: https://raw.githubusercontent.com/kubapok/ireland-news-word-gap/0c6557c8a3cd6d8c77f64618850b2ae82c19476a/dev-0/out.tsv [following]\n",
"--2022-05-13 13:23:09-- https://raw.githubusercontent.com/kubapok/ireland-news-word-gap/0c6557c8a3cd6d8c77f64618850b2ae82c19476a/dev-0/out.tsv\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.108.133, ...\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 63271863 (60M) [text/plain]\n",
"Saving to: out.tsv\n",
"\n",
"out.tsv 100%[===================>] 60,34M 45,1MB/s in 1,3s \n",
"\n",
"2022-05-13 13:23:10 (45,1 MB/s) - out.tsv saved [63271863/63271863]\n",
"\n",
"--2022-05-13 13:23:11-- https://git.wmi.amu.edu.pl/kubapok/ireland-news-word-gap-prediction/raw/branch/master/dev-0/expected.tsv\n",
"Resolving git.wmi.amu.edu.pl (git.wmi.amu.edu.pl)... 150.254.78.40\n",
"Connecting to git.wmi.amu.edu.pl (git.wmi.amu.edu.pl)|150.254.78.40|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 866583 (846K) [text/plain]\n",
"Saving to: expected.tsv.1\n",
"\n",
"expected.tsv.1 100%[===================>] 846,27K 1,91MB/s in 0,4s \n",
"\n",
"2022-05-13 13:23:11 (1,91 MB/s) - expected.tsv.1 saved [866583/866583]\n",
"\n"
]
}
],
"source": [
"!wget https://github.com/kubapok/ireland-news-word-gap/raw/11c72875023c5c01c9d0c0ca39d72c90c840aeb3/dev-0/out.tsv\n",
"!mv out.tsv ./dev-0/out-solution1.tsv\n",
"!wget https://github.com/kubapok/ireland-news-word-gap/raw/0c6557c8a3cd6d8c77f64618850b2ae82c19476a/dev-0/out.tsv\n",
"!mv out.tsv ./dev-0/out-solution2.tsv\n",
"! ( cd dev-0 ; wget https://git.wmi.amu.edu.pl/kubapok/ireland-news-word-gap-prediction/raw/branch/master/dev-0/expected.tsv)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2022-05-13 13:23:12-- https://gonito.net/get/bin/geval\n",
"Resolving gonito.net (gonito.net)... 150.254.78.126\n",
"Connecting to gonito.net (gonito.net)|150.254.78.126|:443... connected.\n",
"HTTP request sent, awaiting response... 200 OK\n",
"Length: 12860136 (12M) [application/octet-stream]\n",
"Saving to: geval.1\n",
"\n",
"geval.1 100%[===================>] 12,26M 2,67MB/s in 4,1s \n",
"\n",
"2022-05-13 13:23:16 (2,97 MB/s) - geval.1 saved [12860136/12860136]\n",
"\n"
]
}
],
"source": [
"!wget https://gonito.net/get/bin/geval\n",
"!chmod u+x geval"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"35.05218788086649\r\n"
]
}
],
"source": [
"!./geval --metric PerplexityHashed -o ./dev-0/out-solution1.tsv -e dev-0/expected.tsv"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"33.47429048442195\r\n"
]
}
],
"source": [
"!./geval --metric PerplexityHashed -o ./dev-0/out-solution2.tsv -e dev-0/expected.tsv"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"def parse_line(line):\n",
"    # Each line holds space-separated word:probability pairs; split on the last colon\n",
"    # so that words which themselves contain ':' keep their full form.\n",
"    pairs = (item.rsplit(':', 1) for item in line.rstrip().split(' '))\n",
"    return {word: float(prob) for word, prob in pairs}\n",
"\n",
"with open('./dev-0/out-solution1.tsv') as s1, open('./dev-0/out-solution2.tsv') as s2, open('./dev-0/out-merge.tsv', 'w') as f_merge:\n",
"    for l1, l2 in zip(s1, s2):\n",
"        probs1 = parse_line(l1)\n",
"        probs2 = parse_line(l2)\n",
"        # Arithmetic mean of both distributions; a word missing from one model counts as 0.0.\n",
"        merged = {k: (probs1.get(k, 0.0) + probs2.get(k, 0.0)) / 2 for k in probs1.keys() | probs2.keys()}\n",
"        merge_line = ' '.join(k + ':' + str(v) for k, v in merged.items()) + '\\n'\n",
"        f_merge.write(merge_line)"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"29.054162509715063\r\n"
]
}
],
"source": [
"!./geval --metric PerplexityHashed -o ./dev-0/out-merge.tsv -e dev-0/expected.tsv"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"An ensemble can be built from:\n",
"\n",
"- several good, independent models\n",
"- several models trained with different random seeds\n",
"- the last few checkpoints"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Ways to combine different models:\n",
"\n",
"- weighted average\n",
"- geometric mean\n",
"- another kind of mean\n",
"- training a separate simple model whose job is to combine the models' outputs (e.g. linear regression)\n",
"\n",
"\n",
"Several models can also be trained jointly, with backpropagation shared across them."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"What are the downsides of ensembles?\n",
"\n",
"- higher model complexity\n",
"- longer inference time\n",
"- higher compute resource usage\n",
"- worse interpretability"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In practice, when competing in a machine learning competition, building an ensemble is always worthwhile.\n",
"\n",
"\n",
"In commercial settings, ensembles are often skipped when time or resources are limited, the model is heavy, or the model's score is not critical.\n",
"In research, or whenever we want to compare several methods, ensembling needlessly obscures the comparison."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Keep in mind that some methods are ensembles by design, e.g. random forests or boosted decision trees."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### TASK\n",
"\n",
"example text: \"ala\" \"ma\" \"kota\" \"MASK\" \"2\" \"psy\" \"i\" \"chomika\"\n",
"\n",
"Build 2 recurrent LSTM networks for Challenging America word-gap prediction. \n",
"- one network should run forward (i.e. like an ordinary network): \"ala\" \"ma\" \"kota\"\n",
"- the other network should run backward: \"chomika\" \"i\" \"psy\" \"2\"\n",
" \n",
"Building the reversed network is very simple: just reverse the word order; there is no need to change the model architecture.\n",
"\n",
"Then build an ensemble of these two models. In the simplest version this can be an arithmetic mean, but other approaches are worth trying, e.g. training the two networks jointly.\n",
"\n",
"\n",
"The requirements are the same as always; the task is visible on gonito."
]
}
],
"metadata": {
"author": "Jakub Pokrywka",
"email": "kubapok@wmi.amu.edu.pl",
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"lang": "pl",
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.3"
},
"subtitle": "0.Informacje na temat przedmiotu[ćwiczenia]",
"title": "Ekstrakcja informacji",
"year": "2021"
},
"nbformat": 4,
"nbformat_minor": 4
}