{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## importy" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/media/kuba/ssdsam/anaconda3/lib/python3.8/site-packages/gensim/similarities/__init__.py:15: UserWarning: The gensim.similarities.levenshtein submodule is disabled, because the optional Levenshtein package is unavailable. Install Levenhstein (e.g. `pip install python-Levenshtein`) to suppress this warning.\n", " warnings.warn(msg)\n" ] } ], "source": [ "from gensim.utils import tokenize\n", "import numpy as np\n", "import torch\n", "from tqdm.notebook import tqdm" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "#device = 'cpu'" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using cuda device\n" ] } ], "source": [ "print('Using {} device'.format(device))" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "device(type='cuda')" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "device" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## przygotowanie zbiorów" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "pan_tadeusz_path_train= '/home/kuba/Syncthing/przedmioty/2020-02/ISI/zajecia9_ngramowy_model_jDDezykowy/pan-tadeusz-train.txt'" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "pan_tadeusz_path_valid= '/home/kuba/Syncthing/przedmioty/2020-02/ISI/zajecia9_ngramowy_model_jDDezykowy/pan-tadeusz-test.txt'" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "corpora_train = open(pan_tadeusz_path_train).read()" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "corpora_train_tokenized = list(tokenize(corpora_train,lowercase = True))" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "vocab_itos = sorted(set(corpora_train_tokenized))" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "16598" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(vocab_itos)" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "vocab_itos = vocab_itos[:15005]\n", "vocab_itos[15001] = \"\"\n", "vocab_itos[15002] = \"\"\n", "vocab_itos[15003] = \"\"\n", "vocab_itos[15004] = \"\"" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "15005" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(vocab_itos)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "vocab_stoi = dict()\n", "for i, token in enumerate(vocab_itos):\n", " vocab_stoi[token] = i" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "NGRAMS = 5" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "def get_token_id(dataset):\n", " token_ids = [vocab_stoi['']] * (NGRAMS-1) + [vocab_stoi['']]\n", " for token in dataset:\n", " try:\n", " token_ids.append(vocab_stoi[token])\n", " except KeyError:\n", " token_ids.append(vocab_stoi[''])\n", " token_ids.append(vocab_stoi[''])\n", " return token_ids" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "train_ids = get_token_id(corpora_train_tokenized)" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "[15004,\n", " 15004,\n", " 15004,\n", " 15004,\n", " 15002,\n", " 7,\n", " 5002,\n", " 7247,\n", " 11955,\n", " 1432,\n", " 7018,\n", " 14739,\n", " 5506,\n", " 4696,\n", " 4276,\n", " 7505,\n", " 2642,\n", " 8477,\n", " 7259,\n", " 10870,\n", " 10530,\n", " 7506,\n", " 12968,\n", " 7997,\n", " 1911,\n", " 12479,\n", " 11129,\n", " 13069,\n", " 11797,\n", " 5819]" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_ids[:30]" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "def get_samples(dataset):\n", " samples = []\n", " for i in range(len(dataset)-NGRAMS):\n", " samples.append(dataset[i:i+NGRAMS])\n", " return samples" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "train_ids = get_samples(train_ids)" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "train_ids = torch.tensor(train_ids, device = device)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "tensor([[15004, 15004, 15004, 15004, 15002],\n", " [15004, 15004, 15004, 15002, 7],\n", " [15004, 15004, 15002, 7, 5002],\n", " [15004, 15002, 7, 5002, 7247],\n", " [15002, 7, 5002, 7247, 11955],\n", " [ 7, 5002, 7247, 11955, 1432],\n", " [ 5002, 7247, 11955, 1432, 7018],\n", " [ 7247, 11955, 1432, 7018, 14739],\n", " [11955, 1432, 7018, 14739, 5506],\n", " [ 1432, 7018, 14739, 5506, 4696],\n", " [ 7018, 14739, 5506, 4696, 4276],\n", " [14739, 5506, 4696, 4276, 7505],\n", " [ 5506, 4696, 4276, 7505, 2642],\n", " [ 4696, 4276, 7505, 2642, 8477],\n", " [ 4276, 7505, 2642, 8477, 7259],\n", " [ 7505, 2642, 8477, 7259, 10870],\n", " [ 2642, 8477, 7259, 10870, 10530],\n", " [ 8477, 7259, 10870, 10530, 7506],\n", " [ 7259, 10870, 10530, 7506, 12968],\n", " [10870, 10530, 7506, 12968, 7997],\n", " [10530, 7506, 12968, 7997, 1911],\n", " [ 7506, 12968, 7997, 1911, 12479],\n", " [12968, 7997, 1911, 12479, 11129],\n", " [ 7997, 1911, 12479, 11129, 13069],\n", " [ 1911, 12479, 11129, 13069, 11797],\n", " [12479, 11129, 13069, 11797, 5819],\n", " [11129, 13069, 11797, 5819, 6268],\n", " [13069, 11797, 5819, 6268, 2807],\n", " [11797, 5819, 6268, 2807, 7831],\n", " [ 5819, 6268, 2807, 7831, 12893]], device='cuda:0')" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_ids[:30]" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([57022, 5])" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_ids.shape" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "corpora_valid = open(pan_tadeusz_path_valid).read()" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "corpora_valid_tokenized = list(tokenize(corpora_valid,lowercase = True))" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "valid_ids = get_token_id(corpora_valid_tokenized)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "valid_ids = torch.tensor(get_samples(valid_ids), dtype = torch.long, device = device)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## model" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [], "source": [ "# https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "class LSTM(torch.nn.Module):\n", "\n", " def __init__(self):\n", " super(LSTM, self).__init__()\n", " self.emb = torch.nn.Embedding(len(vocab_itos),100)\n", " self.rec = torch.nn.LSTM(100, 256, 1, batch_first = True)\n", " self.fc1 = torch.nn.Linear( 256 ,len(vocab_itos))\n", " #self.dropout = torch.nn.Dropout(0.5)\n", "\n", " def forward(self, x):\n", " emb = self.emb(x)\n", " #emb = self.dropout(emb)\n", " output, (h_n, c_n) = self.rec(emb)\n", " hidden = h_n.squeeze(0)\n", " out = self.fc1(hidden)\n", " #out = self.dropout(out)\n", " return out" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "lm = LSTM().to(device)" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "criterion = torch.nn.CrossEntropyLoss()" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "optimizer = torch.optim.Adam(lm.parameters(),lr=0.0001)" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "BATCH_SIZE = 128\n", "EPOCHS = 15" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "def get_ppl(dataset_ids):\n", " lm.eval()\n", "\n", " batches = 0\n", " loss_sum =0\n", " acc_score = 0\n", "\n", " for i in range(0, len(dataset_ids)-BATCH_SIZE+1, BATCH_SIZE):\n", " X = dataset_ids[i:i+BATCH_SIZE,:NGRAMS-1]\n", " Y = dataset_ids[i:i+BATCH_SIZE,NGRAMS-1]\n", " predictions = lm(X)\n", " \n", " # equally distributted\n", " # predictions = torch.zeros_like(predictions)\n", " \n", " loss = criterion(predictions,Y)\n", "\n", " loss_sum += loss.item()\n", " batches += 1\n", "\n", " return np.exp(loss_sum / batches)" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2c8985983fcd4ec5b93ba32b0345d000", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 0\n", "train ppl: 2296.6914856482526\n", "valid ppl: 528.9542436139727\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d2acde1a6e19416cb1eaee8f40f0114f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 1\n", "train ppl: 2093.302103954666\n", "valid ppl: 514.4726844027333\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c9df3b66aede47febbe6852ca62eed17", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 2\n", "train ppl: 2014.09679023559\n", "valid ppl: 510.12146471773366\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d5272c8b25f74c5a9383a829ed30308e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 3\n", "train ppl: 1939.0594855086504\n", "valid ppl: 509.1060151440451\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d6656f2e796b44b7b03d204ecd89f89b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 4\n", "train ppl: 1854.4566511885196\n", "valid ppl: 510.02244291272973\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "86419de761c1443c86c70726099b9838", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 5\n", "train ppl: 1755.030202547313\n", "valid ppl: 508.494174178397\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5d12b6fcfc0045c9b8bac58343036dff", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 6\n", "train ppl: 1646.180912657662\n", "valid ppl: 506.06383737670035\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d6593e6387d5439a9c1301a366fecda6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 7\n", "train ppl: 1533.0501876139222\n", "valid ppl: 504.08067276707567\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "32742824eb00458b8586795a7fb56d84", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 8\n", "train ppl: 1420.680717507558\n", "valid ppl: 502.6906095632547\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5ee78d1ebd03415885d4ea2663c0ee7f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 9\n", "train ppl: 1311.1083504083306\n", "valid ppl: 503.5230045363773\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "b30185826c4341b088934df92839852c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 10\n", "train ppl: 1203.498635587493\n", "valid ppl: 505.7599916969862\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7eeeeb3ef93a45f3a6709a8579164c79", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 11\n", "train ppl: 1100.0681613054269\n", "valid ppl: 507.6071195979723\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "919fd04ac9ed476e931149c15dc460db", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 12\n", "train ppl: 1003.217414775517\n", "valid ppl: 510.07952767103245\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a30dc2c9bd194ead94be7de3275a1e5e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 13\n", "train ppl: 912.2987798296267\n", "valid ppl: 512.8275727599236\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "94f2bac40c83416fad59742f2e2adff4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 14\n", "train ppl: 826.911431868259\n", "valid ppl: 516.1525759633064\n", "\n" ] } ], "source": [ "history_ppl_train = []\n", "history_ppl_valid = []\n", "for epoch in range(EPOCHS):\n", " \n", " batches = 0\n", " loss_sum =0\n", " acc_score = 0\n", " lm.train()\n", " #for i in range(0, len(train_ids)-BATCH_SIZE+1, BATCH_SIZE):\n", " for i in tqdm(range(0, len(train_ids)-BATCH_SIZE+1, BATCH_SIZE)):\n", " X = train_ids[i:i+BATCH_SIZE,:NGRAMS-1]\n", " Y = train_ids[i:i+BATCH_SIZE,NGRAMS-1]\n", " predictions = lm(X)\n", " loss = criterion(predictions,Y)\n", " \n", " \n", " \n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", " \n", " loss_sum += loss.item()\n", " batches += 1\n", " \n", " ppl_train = get_ppl(train_ids)\n", " ppl_valid = get_ppl(valid_ids)\n", " \n", " history_ppl_train.append(ppl_train)\n", " history_ppl_valid.append(ppl_valid)\n", " \n", " print('epoch: ', epoch)\n", " print('train ppl: ', ppl_train)\n", " print('valid ppl: ', ppl_valid)\n", " print()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## parametry modelu" ] }, { "cell_type": "code", "execution_count": 35, "metadata": { "scrolled": false }, "outputs": [ { "data": { "text/plain": [ "[Parameter containing:\n", " tensor([[ 0.4949, -1.2472, -0.7167, ..., -0.0801, -2.1905, 0.9790],\n", " [ 0.2070, 1.7394, 0.8255, ..., -0.5796, -0.3776, 0.8831],\n", " [-1.4559, 1.1073, -0.4904, ..., -1.2919, -2.2661, -0.5476],\n", " ...,\n", " [-0.3706, 0.2133, 0.0484, ..., -0.5792, -0.5769, -0.6941],\n", " [ 0.5502, -0.1212, -2.0879, ..., 0.6764, -0.5961, -0.6282],\n", " [ 0.8362, 0.2193, -0.0807, ..., 2.7741, -0.2589, 0.3310]],\n", " device='cuda:0', requires_grad=True),\n", " Parameter containing:\n", " tensor([[ 0.0252, -0.0744, 0.0817, ..., -0.0559, 0.0896, 0.0208],\n", " [ 0.0423, 0.0329, -0.0610, ..., -0.0009, 0.0169, -0.0361],\n", " [ 0.0507, -0.0838, 0.0520, ..., 0.0395, 0.0067, 0.0173],\n", " ...,\n", " [ 0.0669, 0.0430, -0.0306, ..., 0.0096, 0.0619, -0.0992],\n", " [-0.0153, -0.0888, 0.0580, ..., -0.0433, 0.0399, -0.0494],\n", " [-0.0067, -0.0053, -0.0242, ..., 0.0017, -0.0306, -0.0972]],\n", " device='cuda:0', requires_grad=True),\n", " Parameter containing:\n", " tensor([[ 0.0212, -0.0425, -0.0329, ..., -0.0206, 0.0839, 0.0286],\n", " [ 0.0952, 0.0298, -0.1211, ..., -0.0468, -0.0233, -0.0620],\n", " [ 0.0108, 0.0422, -0.0492, ..., -0.0288, -0.0231, 0.0078],\n", " ...,\n", " [ 0.0253, 0.0154, -0.0765, ..., -0.0025, 0.0057, -0.0408],\n", " [ 0.0892, -0.0928, -0.1039, ..., -0.1531, -0.0011, 0.0180],\n", " [ 0.1341, 0.0666, -0.0548, ..., -0.0573, -0.0376, -0.0813]],\n", " device='cuda:0', requires_grad=True),\n", " Parameter containing:\n", " tensor([0.1115, 0.0423, 0.1438, ..., 0.0361, 0.0297, 0.0956], device='cuda:0',\n", " requires_grad=True),\n", " Parameter containing:\n", " tensor([0.0786, 0.1201, 0.0857, ..., 0.1177, 0.1319, 0.0886], device='cuda:0',\n", " requires_grad=True),\n", " Parameter containing:\n", " tensor([[-0.0158, 0.0236, -0.0958, ..., -0.0906, -0.0678, 0.0057],\n", " [-0.0871, -0.0788, 0.1217, ..., -0.0231, -0.0102, 0.0220],\n", " [ 0.0265, -0.0680, -0.0219, ..., -0.0520, -0.0565, 0.0628],\n", " ...,\n", " [-0.0618, 0.0232, 0.0898, ..., 0.1069, -0.0112, 0.0103],\n", " [-0.0489, 0.0708, 0.0546, ..., 0.1186, -0.0987, 0.1411],\n", " [-0.0764, 0.0463, 0.0947, ..., 0.1104, -0.0312, 0.1118]],\n", " device='cuda:0', requires_grad=True),\n", " Parameter containing:\n", " tensor([ 0.0299, -0.0551, -0.0323, ..., -0.0371, -0.0297, -0.0157],\n", " device='cuda:0', requires_grad=True)]" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "list(lm.parameters())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### krzywe uczenia" ] }, { "cell_type": "code", "execution_count": 36, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "[528.9542436139727,\n", " 514.4726844027333,\n", " 510.12146471773366,\n", " 509.1060151440451,\n", " 510.02244291272973,\n", " 508.494174178397,\n", " 506.06383737670035,\n", " 504.08067276707567,\n", " 502.6906095632547,\n", " 503.5230045363773,\n", " 505.7599916969862,\n", " 507.6071195979723,\n", " 510.07952767103245,\n", " 512.8275727599236,\n", " 516.1525759633064]" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "history_ppl_valid" ] }, { "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[,\n", " ]" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAoWElEQVR4nO3dd5xU9b3/8ddnG71L3QUWBFFEaRtAOhoVe4kkVlARNPZoojH+bpLfvdd784smKho1qDQjGGLDGgvSVNoiRUSBpS/g0pRetnx+f8wBx2UXlm0zO/N+Ph7zmDOfOWfms7D7Pme+58w55u6IiEh8SIh0AyIiUnkU+iIicUShLyISRxT6IiJxRKEvIhJHkiLdwPGcdNJJnp6eHuk2RESqlAULFmxz98aF61Ef+unp6WRmZka6DRGRKsXM1hVV1/COiEgcUeiLiMQRhb6ISBxR6IuIxBGFvohIHFHoi4jEEYW+iEgcidnQnzh3PTNXbI10GyIiUSUmQ/9QXgEvz13HiAmZzF61PdLtiIhEjZgM/ZSkBCbc3IPWjWoyfPx8MtfuiHRLIiJR4bihb2YtzWyamX1tZl+Z2T1B/VEz+8bMlpjZG2ZWP6inm9l+M1sU3J4Le63uZvalmWWZ2Sgzs4r6wRrVrsY/bulJs7rVuXHsfBZt+L6i3kpEpMooyZZ+HnC/u58G9ALuMLOOwEdAJ3c/E1gBPBS2zCp37xLcbgurPwuMBNoHt8Hl8UMUp0md6kwc0YuGtVIY+uJclm7cWZFvJyIS9Y4b+u6+2d2/CKZ3A18Dqe7+obvnBbPNAdKO9Tpm1hyo6+6zPXRh3gnA5WVpviSa1avOxBE9qVM9mRtenMvyb3dX9FuKiEStExrTN7N0oCswt9BTNwPvhz1uY2YLzWyGmfULaqlAdtg82UGtqPcZaWaZZpa5dWvZj8BJa1CTiSN6kpKUwHUvzCFry54yv6aISFVU4tA3s9rAa8C97r4rrP4woSGgl4PSZqCVu3cF7gMmmlldoKjxey/qvdx9tLtnuHtG48ZHnQ66VFo3qsXEEb0A49rn57B2295yeV0RkaqkRKFvZsmEAv9ld389rD4MuBi4Lhiywd0Puvv2YHoBsAo4hdCWffgQUBqwqTx+iJI6uXFtJo7oSV6Bc90Lc8n+bl9lvr2ISMSV5OgdA14Evnb3v4bVBwMPApe6+76wemMzSwym2xLaYbva3TcDu82sV/CaQ4Ep5frTlMApTevw0vAe7D6QyzXPz2Hzzv2V3YKISMSUZEu/D3ADcHbYYZgXAk8DdYCPCh2a2R9YYmaLgVeB29z98IHyvwReALIIfQII3w9QaU5vUY+Xhvfk+725XPv8XLbsOhCJNkREKp0FozJRKyMjwyvqcokL1u3ghhfnkVq/Bq+M7EWj2tUq5H1ERCqbmS1w94zC9Zj8Rm5JdW/dkDE3/oQN3+3j+hfn8f2+Q5FuSUSkQsV16AP0atuI54dmsGrrHm54cR67DuRGuiURkQoT96EP0K99Y567vhvffLuLYWPmsedg3vEXEhGpghT6gbNPbcpT13RjSfZObh47n32HFPwiEnsU+mEGd2rGE7/oQua6HYyYkMmB3PxItyQiUq4U+oVc0rkFjw3pzOertnPbPxZwME/BLyKxQ6FfhCu7pfE/V5zB9OVbuXPiQnLzCyLdkohIuVDoF+OaHq34z8tO56NlOdz7yiLyFPwiEgOSIt1ANBt6VjqH8gr473e/JiUpgceGdCYxocKu+yIiUuEU+sdxS7+2HMwr4NEPlpOSmMD/XnkGCQp+EamiFPolcMegdhzMzWfUJ1kkJxn/dVknKvBKjyIiFUahX0K/OvcUDuYX8PcZq8lc+x3X9mzF5V1TqVs9OdKtiYiUWFyfcO1EuTv/WpDNhNlrWbpxF9WTE7jkzBZc27MVXVrW19a/iESN4k64ptAvpSXZ3zNx7nreWryJfYfyOa153dDWf5cW1NHWv4hEmEK/guw+kMuURZuYOHc9yzbvokZyIpd2Dm39n5lWT1v/IhIRCv0K5u4szt7JpGDrf39uPqe3CG39X9YlldrVtPtERCqPQr8S7TqQy5SFG3l57nq++XY3NVMSuaxLC67t0Zoz0upFuj0RiQMK/QhwdxZtCI39v71kEwdyCzgjtR7X9mzFpZ1bUEtb/yJSQRT6EbZzfy5TFm1kYrD1Xyslkcu6pnJtj1Z0StXWv4iUr1KHvpm1BCYAzYACYLS7P2lmDYF/AunAWuDn7v5dsMxDwHAgH7jb3T8I6t2BcUAN4D3gHj9OA7ES+oe5O1+sD239v7NkEwfzCuicVo9rerTi4s4tNPYvIuWiLKHfHGju7l+YWR1gAXA5cCOww93/ZGa/BRq4+4Nm1hGYBPQAWgAfA6e4e76ZzQPuAeYQCv1R7v7+sd4/1kI/3M59ubyxMJuJ89azImcPNZITufCM5vw8I40ebRrqyB8RKbXiQv+4m5XuvhnYHEzvNrOvgVTgMmBgMNt4YDrwYFB/xd0PAmvMLAvoYWZrgbruPjtoaAKhlccxQz+W1auZzI192jCsdzpfrP+eVxds4O3Fm3nti2xaN6rJVd3S+Fn3NFrUrxHpVkUkRpzQWIKZpQNdgblA02CFgLtvNrMmwWyphLbkD8sOarnBdOF6Ue8zEhgJ0KpVqxNpsUoyM7q3bkD31g34/cWn8/7SzfwrM5u/fLSCv368gr7tTmJIRkvO69iU6smJkW5XRKqwEoe+mdUGXgPudfddxxh6KOoJP0b96KL7aGA0hIZ3StpjLKiRksiV3dK4slsaG3bs418LsnltQTZ3T1pI3epJXNYllSEZaZyRqi9+iciJK1Hom1kyocB/2d1fD8o5ZtY82MpvDmwJ6tlAy7DF04BNQT2tiLoUo2XDmtx37ince057Zq/ezuTMDUzO3MBLc9ZxarM6XNU9jSu6ptKodrVItyoiVURJduQaoTH7He5+b1j9UWB72I7chu7+gJmdDkzkhx25U4H2wY7c+cBdhIaH3gOecvf3jvX+sbwjtzR27s/lnSWbmJyZzeIN35OUYJxzWhOGdG/JwA6NSUrUxdBEpGxH7/QFZgFfEjpkE+B3hIJ7MtAKWA8McfcdwTIPAzcDeYSGg94P6hn8cMjm+8Bd8XbIZnlakbObVxdk8/oX2Wzbc4iTalfjZ91Cwz/tmtSJdHsiEkH6clYMy80vYPryrUzO3MC0b7aQV+B0aVmfIRlpXNK5hc75LxKHFPpxYtueg7y5cCOTMzewImcPNVMSubJbKjf2TtfWv0gcUejHGXdnSfZO/jFnHVMWb+JQXgH92p/EsLPSGXRqE13gXSTGKfTj2PY9B3ll/gZemr2Ob3cdoFXDmgw9qzVDMlpSr4aGfkRikUJfyM0v4MOvchj3+Rrmr/2OmimJ/KxbGsN6t9bQj0iMUejLjyzduJNxn6/lrbChnxt7pzOoQxMSNPQjUuUp9KVIhYd+WjeqydCz0hmSkaajfkSqMIW+HFNufgEffPUt4z5bS+a68KGfdNo1qR3p9kTkBCn0pcSODP0s2sSh/NDQz0190hl4ioZ+RKoKhb6csG17DvLKvPW8NGcdObsOauhHpApR6Eup5eYX8O+l3zL+8x+Gfq7p0YpbB7SlSZ3qkW5PRIqg0Jdy8WX2TsZ+toYpizeRnGhc37M1tw44mcZ1dKZPkWii0JdytXbbXkZ9spI3F26kWlIiN5zVmlv7t9VpnkWihEJfKsTqrXt46pMspiwKhf/Q3q25tf/JNKyVEunWROKaQl8qVNaWPTz1yUreWryJGsmJDOudzsh+bWmg8BeJCIW+VIqsLbt5cmoW7yzZRM3kRG7sk86Ifm2pX1PhL1KZFPpSqVbk7ObJqSt5d8lmaldL4qY+6dzSty31aupQT5HKoNCXiFj+7W6enLqC9778ljrVkripbxuG922js3uKVDCFvkTU15t38eTHK/n3V99Sp3oSw/u24ea+bfQlL5EKUlzoH/cq2mY2xsy2mNnSsNo/zWxRcFtrZouCerqZ7Q977rmwZbqb2ZdmlmVmo4ILrkucOK15XZ67oTvv3t2Xs9o24omPV9L3T58waupKdh/IjXR7InGjJBdG7w/sASa4e6cinv8LsNPd/9PM0oF3iplvHnAPMAd4Dxh1+ILpx6It/di0dONOnvh4JR9/nUO9GsmM6NeGYb3TqaMtf5FyUeotfXefCewo5kUN+Dkw6Thv3hyo6+6zPbSWmQBcXoK+JUZ1Sq3HC8MyePvOvmS0bsBjH66g35+n8fzM1RzMy490eyIx67ihfxz9gBx3XxlWa2NmC81shpn1C2qpQHbYPNlBrUhmNtLMMs0sc+vWrWVsUaLZGWn1ePHGnzDljj6cmVafR977mp/+dQZvLd5EtO9vEqmKyhr61/DjrfzNQCt37wrcB0w0s7pAUeP3xf5Fu/tod89w94zGjRuXsUWpCjq3rM+Em3vw0vAe1K6WzN2TFnL53z5j7urtkW5NJKaUOvTNLAm4Evjn4Zq7H3T37cH0AmAVcAqhLfu0sMXTgE2lfW+JXf3aN+adu/ry2JDO5Ow6yC9Gz2HEhExWbd0T6dZEYkJZtvR/Cnzj7keGbcyssZklBtNtgfbAanffDOw2s17BfoChwJQyvLfEsMQE46ruaUz79UB+c34HZq/aznmPz+T/vPkl2/YcjHR7IlVaSQ7ZnATMBjqYWbaZDQ+eupqjd+D2B5aY2WLgVeA2dz+8E/iXwAtAFqFPAMc9ckfiW42URO4Y1I7pvxnIdT1bMWneBgb8eRpPf7KS/Ye0s1ekNPTlLKkyVm3dw/97/xs+XJZDs7rVue+8U/hZtzQSdQlHkaOU+pBNkWhxcuPajB6aweRbz6Jpveo88OoSLho1i5krdISXSEkp9KXK6dGmIW/e3punrunK3kN5DB0zjxtenMvXm3dFujWRqKfQlyrJzLikcws+vm8A/+ei01iSvZMLR83i1/9azOad+yPdnkjU0pi+xISd+3L52/Qsxn22loQEuKVvW24d0FandZC4pTF9iWn1aibzuwtPY+r9AzivYzOenpbFwEen89LsteTmF0S6PZGoodCXmNKyYU1GXdOVKXf0oV2T2vzHlK84/4mZfPJNjk7rIIJCX2JU55b1eWVkL54fmgEON4/LZNjY+azM2R3p1kQiSqEvMcvMOLdjU/59b3/+4+KOLFz/HYOfnMUfpizl+32HIt2eSEQo9CXmpSQlMLxvG2b8ZhDX9GjJS3PWMeDR6Yz7bI3G+yXuKPQlbjSslcJ/X34G793Tj06pdfnj28u44MlZTF++JdKtiVQahb7EnVOb1eUfw3vy/NAM8vILuHHsfG4aO09n8pS4oNCXuHR4vP+DX/XndxeeSuba7zj/8Zn859vL2LlP1+yV2KXQl7hWLSmRkf1PZtpvBjIkoyVjP1/DwMem8dLsteRpvF9ikEJfBDipdjX+98ozePeufnRoVof/mPIVF436lE9Xbot0ayLlSqEvEqZji7pMGtGL567vxr7cPK5/cS63jM9kzba9kW5NpFwo9EUKMTMGd2rOR78awIODT2X2qm2c9/gMHnl3GbsOaLxfqjaFvkgxqicn8suBofH+K7qm8sKnaxj06HQmzl1PfoFO6SBVk0Jf5Dia1KnOn6/qzNt39qVt41r87o0vuWjULD5fpfF+qXoU+iIl1Cm1HpNvPYu/XduN3QfyuPb5udz6Uibrtmu8X6qOklwYfYyZbTGzpWG1P5rZRjNbFNwuDHvuITPLMrPlZnZ+WL27mX0ZPDfKzHRhU6lyzIyLzmzO1PsH8JvzOzBr5TZ++tcZ/M97X2u8X6qEkmzpjwMGF1F/3N27BLf3AMysI3A1cHqwzDNmlhjM/ywwEmgf3Ip6TZEqoXpyIncMasf0X4fG+5+ftZpBj07n5bnrdHy/RLXjhr67zwR2lPD1LgNecfeD7r4GyAJ6mFlzoK67z/bQSc0nAJeXsmeRqNGk7g/j/Sc3qc3Dbyzl4qd0fL9Er7KM6d9pZkuC4Z8GQS0V2BA2T3ZQSw2mC9eLZGYjzSzTzDK3bt1ahhZFKken1Hr8c2Qvnr2uG3sPHT6+fz6rdT4fiTKlDf1ngZOBLsBm4C9Bvahxej9GvUjuPtrdM9w9o3HjxqVsUaRymRkXnPHD8f1zVu/gvMdn8l/v6Hw+Ej1KFfrunuPu+e5eADwP9AieygZahs2aBmwK6mlF1EVizpHj+389kCEZaYz5LHQ+nwk6n49EgVKFfjBGf9gVwOEje94CrjazambWhtAO23nuvhnYbWa9gqN2hgJTytC3SNRrXKca/3vlmbxzV19ObVaX30/5iguenMWMFRqylMgpySGbk4DZQAczyzaz4cCfg8MvlwCDgF8BuPtXwGRgGfBv4A53zw9e6pfAC4R27q4C3i/vH0YkGp3eoh4TR/Rk9A3dOZRfwLAx87hp7Dyytmi8XyqfhQ6miV4ZGRmemZkZ6TZEysXBvHwmfL6OUVNXsi83nxt6teaec9rToFZKpFuTGGNmC9w9o3Bd38gVqUTVkhIZ0b8t038zkKt/0pIJs9cy8LHpjNX1eqWSKPRFIqBR7Wo8ckXoer1npNbj/769jPOfmMkn3+QQ7Z++pWpT6ItE0KnN6vLS8B68OCwDHG4el8nQMfNYtmlXpFuTGKXQF4kwM+Oc05ry73v78x8Xd2RJ9k4uemoWD7y6mJxdByLdnsQY7cgViTI79+Xy9LSVjP98HYkJxoj+bbm1f1tqVUuKdGtShWhHrkgVUa9mMg9f1JGP7xvAOac1YdTUlQx8bDqT5uniLVJ2Cn2RKNWqUU2evrYbr9/em1YNa/LQ619y4ZOzmL58i3b2Sqkp9EWiXLdWDXj1trN49rpuHMjL58ax87WzV0pNoS9SBYSfzO33F3fky42hnb2/+ddivt2pnb1SctqRK1IFaWevHI925IrEkMM7e6feP4CfdmzKqKkrGfBoaGevzuQpx6LQF6nCWjasyVPXdOWN23uT3ijY2TtqFtO0s1eKodAXiQFdWzXgX7edxXPXd+NQXgE3jZ3PDS9qZ68cTaEvEiPMjMGdmvPhrwbwh0s6snSTdvbK0bQjVyRG7dyfyzPTshj72VoSEuCWvm25dUBb6lRPjnRrUgm0I1ckztSrkcxDF57G1PsHcG7HZjw9LYsBj05n/OdrOZSnnb3xSqEvEuMO7+x9684+dGhahz+89RXnPj6Dd5ds1s7eOKTQF4kTZ6bVZ+KInoy96SdUT0rkjolfcPkznzN39fZItyaVqCTXyB1jZlvMbGlY7VEz+8bMlpjZG2ZWP6inm9l+M1sU3J4LW6Z7cF3dLDMbFVwgXUQqkZkxqEMT3runH49edSY5Ow/wi9FzuGX8fFbm7I50e1IJSrKlPw4YXKj2EdDJ3c8EVgAPhT23yt27BLfbwurPAiOB9sGt8GuKSCVJTDCGZLRk+m8G8sDgDsxdvYPzn5jJb19bonP4x7jjhr67zwR2FKp96O55wcM5QNqxXsPMmgN13X22hwYRJwCXl6pjESk31ZMTuX1gO2Y8MIgbe7fhtS+yGfDoNB77YDm7D+RGuj2pAOUxpn8z8H7Y4zZmttDMZphZv6CWCmSHzZMd1EQkCjSslcLvL+nI1PsGcl7YkT7jPlujI31iTJlC38weBvKAl4PSZqCVu3cF7gMmmlldoKjx+2IPGzCzkWaWaWaZW7duLUuLInICWjWqyaiwI33++PYyzn18Bu8s2aQjfWJEqUPfzIYBFwPXBUM2uPtBd98eTC8AVgGnENqyDx8CSgM2Fffa7j7a3TPcPaNx48albVFESqnwkT53TlzI5c98zhwd6VPllSr0zWww8CBwqbvvC6s3NrPEYLotoR22q919M7DbzHoFR+0MBaaUuXsRqTBFHelz9eg5DB83nxU60qfKKskhm5OA2UAHM8s2s+HA00Ad4KNCh2b2B5aY2WLgVeA2dz+8E/iXwAtAFqFPAOH7AUQkShU+0mfemh0MfmImD766ROf0qYJ07h0ROSE79h7i6U+yeGnOWhITjGG907mt/8k0qJUS6dYkTHHn3lHoi0iprN++j8c/XsGbizZSOyWJkf3bclPfNtTW1buigkJfRCrE8m9385cPl/Phshwa1Urh9kHtuK5nK6onJ0a6tbim0BeRCrVw/Xc89uFyPsvaTot61bnnp+35Wbc0khJ1iq9I0KmVRaRCdW3VgJdv6cXLt/SkSd3qPPjal5z7+EzeXryJgoLo3riMJwp9ESlXfdqdxBu39+b5oRmkJCZw16SFXPTUp3zyTY6+4BUFFPoiUu7MjHM7NuW9e/rxxC+6sPdgHjePy2TIc7N1KucIU+iLSIVJTDAu75rK1PsH8MgVndjw3T5+MXoOQ8fM48vsnZFuLy5pR66IVJoDuflMmL2WZ6av4vt9uVzQqRn3n3cK7ZrUiXRrMUdH74hI1Nh9IJcXZq3hhVmr2Z+bz5Xd0rjnnPa0bFgz0q3FDIW+iESdHXsP8ez0LMbPXoe7c22PVtxxdjua1Kke6daqPIW+iEStzTv3M2pqFpMzN5CSmMCw3umM6NeGRrWrRbq1KkuhLyJRb+22vTz+8QreWryJ6kmJXN+rFSP6t9WWfyko9EWkysjasodnpmXx5qKNJCcmcG3PVtza/2Sa1VP4l5RCX0SqnLXb9vK3aVm8sXAjCWb84ictuW3gyaTWrxHp1qKeQl9EqqwNO/bxzPRVvLpgAwBXdU/j9oHtdLTPMSj0RaTK2/j9fv4+YxWvzNtAvjtXdk3ljkHtSD+pVqRbizoKfRGJGTm7DvD3Gat5ee46cvMLuKxLKPzbNakd6daihkJfRGLOlt0HeGHWGl6avY4DeflcdEZz7jq7PR2a6Ru+Cn0RiVnb9xzkxU/XMP7ztew9lM8FnZpx59ntOL1FvUi3FjGlPp++mY0xsy1mtjSs1tDMPjKzlcF9g7DnHjKzLDNbbmbnh9W7m9mXwXOjzMzK4wcTEWlUuxoPDD6Vz357Nnef055Ps7Zx0ahPuWV8Jkuyv490e1GlJGfZHAcMLlT7LTDV3dsDU4PHmFlH4Grg9GCZZ8zs8DXTngVGAu2DW+HXFBEpk/o1U7jv3FP49MGzue/cU5i/dgeXPv0ZN46dx4J130W6vahw3NB395nAjkLly4DxwfR44PKw+ivuftDd1wBZQA8zaw7UdffZHhpPmhC2jIhIuapXIzm0xf/gIB4Y3IEl2Tv52bOfc+3zc5ixYmtcX8yltOfTb+rumwGC+yZBPRXYEDZfdlBLDaYL14tkZiPNLNPMMrdu3VrKFkUk3tWpnsztA9vx6YODePjC01i9dS/Dxszjgidn8ebCjeTmF0S6xUpX3hdRKWqc3o9RL5K7j3b3DHfPaNy4cbk1JyLxqWZKEiP6t2XmA4N4bEhn8guce/+5iIGPTufFT9ew92BepFusNKUN/ZxgyIbgfktQzwZahs2XBmwK6mlF1EVEKk1KUgJXdU/jg3v7M+bGDFIb1OC/3llG7z99wmMfLGfbnoORbrHClTb03wKGBdPDgClh9avNrJqZtSG0w3ZeMAS028x6BUftDA1bRkSkUiUkGGef2pTJt57F67f35qy2jfjb9Cx6/+kTfvfGl6zZtjfSLVaY4x6nb2aTgIHASUAO8AfgTWAy0ApYDwxx9x3B/A8DNwN5wL3u/n5QzyB0JFAN4H3gLi/B3hQdpy8ilWH11j08P2sNr32RTW5+AYNPb8atA06mS8v6kW6tVPTlLBGREtiy+wDjP1/LS7PXsetAHj3bNOS2ASczsENjqtLXixT6IiInYM/BPF6Zt54XP13D5p0H6NC0DiP7t+WSzi1ISSrvY2DKn0JfRKQUcvMLeHvxJv4+YzXLc3bTvF51hvdtw9U9WlG7WlKk2yuWQl9EpAzcnekrtvL3GauYs3oHdaoncUOv1tzYJz0qL+eo0BcRKSeLNnzP6JmreH/ptyQnJHBJ5xbc1CedTqnRc4I3hb6ISDlbs20vYz5dw6sLstmfm0+P9Ibc1Cedczs2JSkxsuP+Cn0RkQqyc38uk+dvYPzstWR/t5/U+jUYelZrfvGTltSvmRKRnhT6IiIVLL/A+WhZDuM+X8Oc1TuokZzIFd1Sual3Ou2bVu6FXRT6IiKVaNmmXYz7fA1vLtrEobwC+rU/iRt7pzOoQxMSEir+eH+FvohIBGzfc5BJ89bz0px15Ow6SHqjmgzrnc5V3dOoUz25wt5XoS8iEkG5+QW8v/Rbxn62hoXrv6d2tSSGZKQx7Kx00k+qVe7vp9AXEYkSizZ8z9jP1vDuks3ku3N2hybc1KcNfdo1KrdTPSj0RUSiTM6uA7w8Zx0vz13P9r2HOKVpbW7s3YYruqZSIyXx+C9wDAp9EZEodSA3n7cXb2LsZ2tZtnkX9Wokc3WPltw+sB31apRu3L+40I/+swaJiMS46smJDMloybt392XyrWfR++RGvJqZTUoFfMEres8WJCISZ8yMHm0a0qNNQ/YezCvzEE9RtKUvIhKFalXQGTwV+iIicUShLyISRxT6IiJxpNShb2YdzGxR2G2Xmd1rZn80s41h9QvDlnnIzLLMbLmZnV8+P4KIiJRUqfcUuPtyoAuAmSUCG4E3gJuAx939sfD5zawjcDVwOtAC+NjMTnH3/NL2ICIiJ6a8hnfOAVa5+7pjzHMZ8Iq7H3T3NUAW0KOc3l9EREqgvEL/amBS2OM7zWyJmY0xswZBLRXYEDZPdlA7ipmNNLNMM8vcunVrObUoIiJlDn0zSwEuBf4VlJ4FTiY09LMZ+MvhWYtYvMhzQLj7aHfPcPeMxo0bl7VFEREJlMeW/gXAF+6eA+DuOe6e7+4FwPP8MISTDbQMWy4N2FQO7y8iIiVUHqF/DWFDO2bWPOy5K4ClwfRbwNVmVs3M2gDtgXnl8P4iIlJCZfqer5nVBM4Fbg0r/9nMuhAaull7+Dl3/8rMJgPLgDzgDh25IyJSucoU+u6+D2hUqHbDMeZ/BHikLO8pIiKlp2/kiojEEYW+iEgcUeiLiMQRhb6ISBxR6IuIxBGFvohIHFHoi4jEEYW+iEgcUeiLiMQRhb6ISBxR6IuIxBGFvohIHFHoi4jEEYW+iEgcUeiLiMQRhb6ISBxR6IuIxJEyXTkrqk25E/ZthxoNoWYDqNEgmG4Ydh/UkqtHulsRkUpR1mvkrgV2A/lAnrtnmFlD4J9AOqFr5P7c3b8L5n8IGB7Mf7e7f1CW9z+mvIPw/XrYtAj274C8A8XPm1wzFP41GgQriLCVQ40GR68okqpDYgokJge3FEhIhgR9cBKJa+5QkPfDLT8XCvKDx7nBfX5Qzyv6lh82ferF5Z4r5bGlP8jdt4U9/i0w1d3/ZGa/DR4/aGYdgauB04EWwMdmdkqFXRz9Z8//+HHufti3I7QC2LcD9n/34+nwWs5Xofv934EXlPw9LbHolcHh6cSk4D4FEsKmE5NC81kCmAFWxDTF1I81TejePXQj/L6giFpQ/1GNImoeel1L+OG9jkwXuoX3U9w84a+TkBj6d/zRfULo3+uoWtjjhKRgOqGI5cPqR5ZLCKufyHOHp2NoBV9QAJ4fCqPwe/ejawX5od+HI/MU/FArfDtSD3u+oKCIWhHLHw5Kz/9xf0fV8ororVDtqJ8hL7T8kdfK+2G+H9WKmefI47B5Dt/KO84ezoGE8h2JqIjhncuAgcH0eGA68GBQf8XdDwJrzCwL6AHMroAejpZcA+qlhm4lVVAAB3cFK4fvflgR5B0IranzcyH/UGgNfni6yHrYcwVh07n7f1wvHLrFhXGx08XM/6MVgv14+kf3CcepEfYcP6w4ivqDP3LzY8+Hl8t/b0QccyVXeEVX1ErPCs1/+DWKWkmH/98e47ljrdgLB/HhAKvqjqzwD6/8i9ogSCi0cZAUWnH/6HEiJKUcf54jr53041pC8g+1xKSw58NuiclhyySFLZMY9lz4/Cnl/s9V1tB34EMzc+Dv7j4aaOrumwHcfbOZNQnmTQXmhC2bHdSOYmYjgZEArVq1KmOLZZCQADXqh24NI9dGTPOwQDrWFmVxW37Fbs3lFzFv2BZm+NblUUFYzNZqQbD1G/4aFF6Zhf88RazgjloZFp7Op+gV8rFWzBw975EVUHjtcGCFTxfxaehHn2iK+FQV/mmr8Kejw5+CjqzECn+CSihiuWKWDQ/zIj/tJcbWJ65KUtbQ7+Pum4Jg/8jMvjnGvFZErcjNvGDlMRogIyOjCm8KynEd+eSRQCwfVyASLcq0mnT3TcH9FuANQsM1OWbWHCC43xLMng20DFs8DdhUlvcXEZETU+rQN7NaZlbn8DRwHrAUeAsYFsw2DJgSTL8FXG1m1cysDdAemFfa9xcRkRNXls/TTYE3LLRTLwmY6O7/NrP5wGQzGw6sB4YAuPtXZjYZWAbkAXdU2JE7IiJSpFKHvruvBjoXUd8OnFPMMo8Aj5T2PUVEpGy061tEJI4o9EVE4ohCX0Qkjij0RUTiiLlH93efzGwrsK6Ui58EbDvuXNGhKvUKVavfqtQrVK1+q1KvULX6LWuvrd29ceFi1Id+WZhZprtnRLqPkqhKvULV6rcq9QpVq9+q1CtUrX4rqlcN74iIxBGFvohIHIn10B8d6QZOQFXqFapWv1WpV6ha/ValXqFq9Vshvcb0mL6IiPxYrG/pi4hIGIW+iEgcicnQN7PBZrbczLKC6/RGLTNraWbTzOxrM/vKzO6JdE/HY2aJZrbQzN6JdC/HY2b1zexVM/sm+Dc+K9I9FcfMfhX8Diw1s0lmVr4XRy0jMxtjZlvMbGlYraGZfWRmK4P7BpHsMVwx/T4a/C4sMbM3zKx+BFs8oqhew577tZm5mZ1UHu8Vc6FvZonA34ALgI7ANcFF2aNVHnC/u58G9ALuiPJ+Ae4Bvo50EyX0JPBvdz+V0Flho7JvM0sF7gYy3L0TkAhcHdmujjIOGFyo9ltgqru3B6YGj6PFOI7u9yOgk7ufCawAHqrspooxjqN7xcxaAucSOk19uYi50Cd09a4sd1/t7oeAVwhdlD0quftmd/8imN5NKJRO4OrtlcvM0oCLgBci3cvxmFldoD/wIoC7H3L37yPa1LElATXMLAmoSZRdWc7dZwI7CpUvA8YH0+OByyuzp2Mpql93/9Dd84KHcwhdwS/iivm3BXgceIBiLi1bGrEY+qnAhrDHxV6APdqYWTrQFZgb4VaO5QlCv4QFEe6jJNoCW4GxwXDUC8FV3qKOu28EHiO0RbcZ2OnuH0a2qxJp6u6bIbQBAzSJcD8n4mbg/Ug3URwzuxTY6O6Ly/N1YzH0S3wB9mhiZrWB14B73X1XpPspipldDGxx9wWR7qWEkoBuwLPu3hXYS3QNPxwRjIVfBrQBWgC1zOz6yHYVu8zsYUJDqy9HupeimFlN4GHg9+X92rEY+lXuAuxmlkwo8F9299cj3c8x9AEuNbO1hIbNzjazf0S2pWPKBrLd/fAnp1cJrQSi0U+BNe6+1d1zgdeB3hHuqSRyzKw5QHC/JcL9HJeZDQMuBq7z6P2i0smENgAWB39vacAXZtasrC8ci6E/H2hvZm3MLIXQzrC3ItxTsSx0keEXga/d/a+R7udY3P0hd09z93RC/66fuHvUbo26+7fABjPrEJTOIXSN5mi0HuhlZjWD34lziNKdzoW8BQwLpocBUyLYy3GZ2WDgQeBSd98X6X6K4+5funsTd08P/t6ygW7B73SZxFzoBztp7gQ+IPRHM9ndv4psV8fUB7iB0FbzouB2YaSbiiF3AS+b2RKgC/A/kW2naMGnkVeBL4AvCf1tRtUpA8xsEjAb6GBm2WY2HPgTcK6ZrSR0lMmfItljuGL6fRqoA3wU/K09F9EmA8X0WjHvFb2fbkREpLzF3Ja+iIgUT6EvIhJHFPoiInFEoS8iEkcU+iIicUShLyISRxT6IiJx5P8DBjfbPAvZaBoAAAAASUVORK5CYII=\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "plt.plot(np.arange(len(history_ppl_train)), history_ppl_train, history_ppl_valid)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inferencja" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'Gości innych nie widział oprócz spółleśników'" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "'Gości innych nie widział oprócz spółleśników'" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "tokenized = list(tokenize('Gości innych nie widział oprócz spółleśników',lowercase = True))" ] }, { "cell_type": "code", "execution_count": 41, "metadata": {}, "outputs": [], "source": [ "#tokenized = tokenized[-NGRAMS :-1 ]" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['gości', 'innych', 'nie', 'widział', 'oprócz', 'spółleśników']" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenized" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "ids = []\n", "for word in tokenized:\n", " if word in vocab_stoi:\n", " ids.append(vocab_stoi[word])\n", " else:\n", " ids.append(vocab_stoi[''])" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "[2671, 3168, 5873, 13240, 6938, 15001]" ] }, "execution_count": 44, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ids" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "LSTM(\n", " (emb): Embedding(15005, 100)\n", " (rec): LSTM(100, 256, batch_first=True)\n", " (fc1): Linear(in_features=256, out_features=15005, bias=True)\n", ")" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lm.eval()" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [], "source": [ "ids = torch.tensor(ids, dtype = torch.long, device = device)" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensor([ 2671, 3168, 5873, 13240, 6938, 15001], device='cuda:0')" ] }, "execution_count": 47, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ids" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [ "preds= lm(ids.unsqueeze(0))" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "15001" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.argmax(torch.softmax(preds,1),1).item()" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.return_types.max(\n", "values=tensor([0.1419], device='cuda:0', grad_fn=),\n", "indices=tensor([15001], device='cuda:0'))" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "torch.max(torch.softmax(preds,1),1)" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "''" ] }, "execution_count": 51, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vocab_itos[torch.argmax(torch.softmax(preds,1),1).item()]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## ZADANIE: GENEROWANIE TEKSTU\n", "\n", "Napisać funkcję generującą tekst, która dla podanego fragmentu generuje tekst.\n", "Generowanie tekstu ma wyglądać następująco: Z 10 najbardziej prawodpodobnych tokenów należy wylosować jeden, ala ma to byc token inny niż specjalny (UNK, BOS, EOS, PAD). \n", "\n", "wygenerować tekst o długości 30 tokenów" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }