From e001c96e47315dc1d3cfbebba9b41cebf23e8484 Mon Sep 17 00:00:00 2001 From: kubapok Date: Sat, 29 May 2021 22:02:38 +0200 Subject: [PATCH] init --- lstm - ODPOWIEDZI.ipynb | 1465 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 1465 insertions(+) create mode 100644 lstm - ODPOWIEDZI.ipynb diff --git a/lstm - ODPOWIEDZI.ipynb b/lstm - ODPOWIEDZI.ipynb new file mode 100644 index 0000000..5e571ae --- /dev/null +++ b/lstm - ODPOWIEDZI.ipynb @@ -0,0 +1,1465 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## importy" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/media/kuba/ssdsam/anaconda3/lib/python3.8/site-packages/gensim/similarities/__init__.py:15: UserWarning: The gensim.similarities.levenshtein submodule is disabled, because the optional Levenshtein package is unavailable. Install Levenhstein (e.g. `pip install python-Levenshtein`) to suppress this warning.\n", + " warnings.warn(msg)\n" + ] + } + ], + "source": [ + "from gensim.utils import tokenize\n", + "import numpy as np\n", + "import torch\n", + "from tqdm.notebook import tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", + "#device = 'cpu'" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Using cuda device\n" + ] + } + ], + "source": [ + "print('Using {} device'.format(device))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "device(type='cuda')" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "device" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## przygotowanie zbiorów" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "pan_tadeusz_path_train= '/home/kuba/Syncthing/przedmioty/2020-02/ISI/zajecia9_ngramowy_model_jDDezykowy/pan-tadeusz-train.txt'" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "pan_tadeusz_path_valid= '/home/kuba/Syncthing/przedmioty/2020-02/ISI/zajecia9_ngramowy_model_jDDezykowy/pan-tadeusz-test.txt'" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "corpora_train = open(pan_tadeusz_path_train).read()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "corpora_train_tokenized = list(tokenize(corpora_train,lowercase = True))" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "vocab_itos = sorted(set(corpora_train_tokenized))" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "16598" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(vocab_itos)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [], + "source": [ + "vocab_itos = vocab_itos[:15005]\n", + "vocab_itos[15001] = \"\"\n", + "vocab_itos[15002] = \"\"\n", + "vocab_itos[15003] = \"\"\n", + "vocab_itos[15004] = \"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "15005" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "len(vocab_itos)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "vocab_stoi = dict()\n", + "for i, token in enumerate(vocab_itos):\n", + " vocab_stoi[token] = i" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "NGRAMS = 5" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [], + "source": [ + "def get_token_id(dataset):\n", + " token_ids = [vocab_stoi['']] * (NGRAMS-1) + [vocab_stoi['']]\n", + " for token in dataset:\n", + " try:\n", + " token_ids.append(vocab_stoi[token])\n", + " except KeyError:\n", + " token_ids.append(vocab_stoi[''])\n", + " token_ids.append(vocab_stoi[''])\n", + " return token_ids" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "train_ids = get_token_id(corpora_train_tokenized)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[15004,\n", + " 15004,\n", + " 15004,\n", + " 15004,\n", + " 15002,\n", + " 7,\n", + " 5002,\n", + " 7247,\n", + " 11955,\n", + " 1432,\n", + " 7018,\n", + " 14739,\n", + " 5506,\n", + " 4696,\n", + " 4276,\n", + " 7505,\n", + " 2642,\n", + " 8477,\n", + " 7259,\n", + " 10870,\n", + " 10530,\n", + " 7506,\n", + " 12968,\n", + " 7997,\n", + " 1911,\n", + " 12479,\n", + " 11129,\n", + " 13069,\n", + " 11797,\n", + " 5819]" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_ids[:30]" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "def get_samples(dataset):\n", + " samples = []\n", + " for i in range(len(dataset)-NGRAMS):\n", + " samples.append(dataset[i:i+NGRAMS])\n", + " return samples" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "train_ids = get_samples(train_ids)" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "train_ids = torch.tensor(train_ids, device = device)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([[15004, 15004, 15004, 15004, 15002],\n", + " [15004, 15004, 15004, 15002, 7],\n", + " [15004, 15004, 15002, 7, 5002],\n", + " [15004, 15002, 7, 5002, 7247],\n", + " [15002, 7, 5002, 7247, 11955],\n", + " [ 7, 5002, 7247, 11955, 1432],\n", + " [ 5002, 7247, 11955, 1432, 7018],\n", + " [ 7247, 11955, 1432, 7018, 14739],\n", + " [11955, 1432, 7018, 14739, 5506],\n", + " [ 1432, 7018, 14739, 5506, 4696],\n", + " [ 7018, 14739, 5506, 4696, 4276],\n", + " [14739, 5506, 4696, 4276, 7505],\n", + " [ 5506, 4696, 4276, 7505, 2642],\n", + " [ 4696, 4276, 7505, 2642, 8477],\n", + " [ 4276, 7505, 2642, 8477, 7259],\n", + " [ 7505, 2642, 8477, 7259, 10870],\n", + " [ 2642, 8477, 7259, 10870, 10530],\n", + " [ 8477, 7259, 10870, 10530, 7506],\n", + " [ 7259, 10870, 10530, 7506, 12968],\n", + " [10870, 10530, 7506, 12968, 7997],\n", + " [10530, 7506, 12968, 7997, 1911],\n", + " [ 7506, 12968, 7997, 1911, 12479],\n", + " [12968, 7997, 1911, 12479, 11129],\n", + " [ 7997, 1911, 12479, 11129, 13069],\n", + " [ 1911, 12479, 11129, 13069, 11797],\n", + " [12479, 11129, 13069, 11797, 5819],\n", + " [11129, 13069, 11797, 5819, 6268],\n", + " [13069, 11797, 5819, 6268, 2807],\n", + " [11797, 5819, 6268, 2807, 7831],\n", + " [ 5819, 6268, 2807, 7831, 12893]], device='cuda:0')" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_ids[:30]" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.Size([57022, 5])" + ] + }, + "execution_count": 22, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_ids.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [], + "source": [ + "corpora_valid = open(pan_tadeusz_path_valid).read()" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "corpora_valid_tokenized = list(tokenize(corpora_valid,lowercase = True))" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "valid_ids = get_token_id(corpora_valid_tokenized)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "valid_ids = torch.tensor(get_samples(valid_ids), dtype = torch.long, device = device)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## model" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "# https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "class LSTM(torch.nn.Module):\n", + "\n", + " def __init__(self):\n", + " super(LSTM, self).__init__()\n", + " self.emb = torch.nn.Embedding(len(vocab_itos),100)\n", + " self.rec = torch.nn.LSTM(100, 256, 1, batch_first = True)\n", + " self.fc1 = torch.nn.Linear( 256 ,len(vocab_itos))\n", + " #self.dropout = torch.nn.Dropout(0.5)\n", + "\n", + " def forward(self, x):\n", + " emb = self.emb(x)\n", + " #emb = self.dropout(emb)\n", + " output, (h_n, c_n) = self.rec(emb)\n", + " hidden = h_n.squeeze(0)\n", + " out = self.fc1(hidden)\n", + " #out = self.dropout(out)\n", + " return out" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "lm = LSTM().to(device)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "criterion = torch.nn.CrossEntropyLoss()" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "optimizer = torch.optim.Adam(lm.parameters(),lr=0.0001)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [], + "source": [ + "BATCH_SIZE = 128\n", + "EPOCHS = 15" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "def get_ppl(dataset_ids):\n", + " lm.eval()\n", + "\n", + " batches = 0\n", + " loss_sum =0\n", + " acc_score = 0\n", + "\n", + " for i in range(0, len(dataset_ids)-BATCH_SIZE+1, BATCH_SIZE):\n", + " X = dataset_ids[i:i+BATCH_SIZE,:NGRAMS-1]\n", + " Y = dataset_ids[i:i+BATCH_SIZE,NGRAMS-1]\n", + " predictions = lm(X)\n", + " \n", + " # equally distributted\n", + " # predictions = torch.zeros_like(predictions)\n", + " \n", + " loss = criterion(predictions,Y)\n", + "\n", + " loss_sum += loss.item()\n", + " batches += 1\n", + "\n", + " return np.exp(loss_sum / batches)" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "610f61c2bfcb4102af04aa8964ebb8a3", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "epoch: 0\n", + "train ppl: 2287.3500820554295\n", + "valid ppl: 531.0113829517392\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "6785c17bb88440109c7591d235f063d4", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "epoch: 1\n", + "train ppl: 2082.7357174891326\n", + "valid ppl: 516.542261379181\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "4a800dbbef5d4a98871d7dfe1ae71cef", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "epoch: 2\n", + "train ppl: 1998.148471220956\n", + "valid ppl: 510.99021596013944\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c2aec3dd6759434fabb56c5b91bcfe9c", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "epoch: 3\n", + "train ppl: 1913.740628231292\n", + "valid ppl: 508.91670819123317\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "06ac6c8ea5db49d3ad35125aeab8826d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "epoch: 4\n", + "train ppl: 1817.7626005392221\n", + "valid ppl: 509.691281725011\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "58b7d3ce5bc1458c960e055fa3fa938a", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "epoch: 5\n", + "train ppl: 1708.5886654297572\n", + "valid ppl: 509.5967005513094\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "7bce3902c5a54be7af646fb9da6eb536", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "epoch: 6\n", + "train ppl: 1590.5836574012103\n", + "valid ppl: 510.39878889794727\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f3eae4a0767d4cdaa65df064bd52d058", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "epoch: 7\n", + "train ppl: 1471.1897079151051\n", + "valid ppl: 510.97901616528486\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0eaef289e94242d4bbb3e78721b7dc3b", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "epoch: 8\n", + "train ppl: 1354.9050103780992\n", + "valid ppl: 511.1877020218083\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "03d0dc64bbef4893ad0d0047807c0d39", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "epoch: 9\n", + "train ppl: 1243.0030151749731\n", + "valid ppl: 511.01325486962844\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "0fccd1d763af49849caa497326e74bed", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "epoch: 10\n", + "train ppl: 1135.9030031936575\n", + "valid ppl: 511.9617587007096\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "65afb8b0bf3d46d9bc5c567c22889071", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "epoch: 11\n", + "train ppl: 1034.278462241148\n", + "valid ppl: 514.2758177779934\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "915c092f72e1471a91202be1b371ba38", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "epoch: 12\n", + "train ppl: 938.3817179142794\n", + "valid ppl: 517.9274067739209\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "726ec81fde69442d987d39adcece9da2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "epoch: 13\n", + "train ppl: 849.432099748702\n", + "valid ppl: 522.2558997769627\n", + "\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5758908c8e2243f4b91f67684ed296c0", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "epoch: 14\n", + "train ppl: 767.7593036002053\n", + "valid ppl: 527.0581693919813\n", + "\n" + ] + } + ], + "source": [ + "history_ppl_train = []\n", + "history_ppl_valid = []\n", + "for epoch in range(EPOCHS):\n", + " \n", + " batches = 0\n", + " loss_sum =0\n", + " acc_score = 0\n", + " lm.train()\n", + " #for i in range(0, len(train_ids)-BATCH_SIZE+1, BATCH_SIZE):\n", + " for i in tqdm(range(0, len(train_ids)-BATCH_SIZE+1, BATCH_SIZE)):\n", + " X = train_ids[i:i+BATCH_SIZE,:NGRAMS-1]\n", + " Y = train_ids[i:i+BATCH_SIZE,NGRAMS-1]\n", + " predictions = lm(X)\n", + " loss = criterion(predictions,Y)\n", + " \n", + " \n", + " \n", + " optimizer.zero_grad()\n", + " loss.backward()\n", + " optimizer.step()\n", + " \n", + " loss_sum += loss.item()\n", + " batches += 1\n", + " \n", + " ppl_train = get_ppl(train_ids)\n", + " ppl_valid = get_ppl(valid_ids)\n", + " \n", + " history_ppl_train.append(ppl_train)\n", + " history_ppl_valid.append(ppl_valid)\n", + " \n", + " print('epoch: ', epoch)\n", + " print('train ppl: ', ppl_train)\n", + " print('valid ppl: ', ppl_valid)\n", + " print()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## parametry modelu" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[Parameter containing:\n", + " tensor([[-0.4749, -1.2938, 0.0550, ..., 0.2910, 0.6960, 0.1963],\n", + " [-0.4820, 0.0798, -0.3073, ..., -0.3827, -0.2542, -1.5678],\n", + " [-0.7655, -0.7346, -0.3758, ..., -0.6114, -0.0162, -0.8364],\n", + " ...,\n", + " [-0.3525, 1.0935, -1.2036, ..., 0.0430, 0.0183, -0.6422],\n", + " [-0.2081, -1.1451, 0.5068, ..., 2.8449, -0.8814, -1.0899],\n", + " [-0.5691, 0.3699, 0.0096, ..., 1.0125, -2.3366, 0.3840]],\n", + " device='cuda:0', requires_grad=True),\n", + " Parameter containing:\n", + " tensor([[ 0.0279, -0.0695, 0.1370, ..., 0.0381, 0.0495, -0.0814],\n", + " [ 0.0943, 0.0669, 0.0204, ..., -0.0343, -0.0033, -0.0528],\n", + " [ 0.0070, -0.0610, 0.0476, ..., 0.0744, -0.0443, 0.0575],\n", + " ...,\n", + " [-0.0458, -0.0248, 0.0011, ..., -0.0125, -0.0303, 0.0601],\n", + " [ 0.0176, -0.1003, -0.0006, ..., -0.0623, -0.0228, -0.0785],\n", + " [ 0.0718, -0.0176, 0.0415, ..., -0.0435, 0.0486, 0.1307]],\n", + " device='cuda:0', requires_grad=True),\n", + " Parameter containing:\n", + " tensor([[ 0.1083, 0.0527, -0.1201, ..., -0.0371, 0.0479, 0.1017],\n", + " [ 0.0464, 0.0560, 0.0180, ..., 0.0050, 0.0446, 0.0313],\n", + " [ 0.0829, 0.0184, -0.0865, ..., -0.0577, 0.0473, 0.0306],\n", + " ...,\n", + " [ 0.0323, 0.0655, -0.1150, ..., -0.0961, 0.0824, 0.0394],\n", + " [ 0.0355, 0.0544, -0.0869, ..., 0.0256, 0.0507, -0.0386],\n", + " [ 0.0396, 0.0834, -0.0596, ..., -0.0664, 0.0457, 0.0711]],\n", + " device='cuda:0', requires_grad=True),\n", + " Parameter containing:\n", + " tensor([0.1295, 0.0468, 0.1477, ..., 0.0589, 0.1052, 0.0321], device='cuda:0',\n", + " requires_grad=True),\n", + " Parameter containing:\n", + " tensor([0.1343, 0.0621, 0.0573, ..., 0.1319, 0.0422, 0.1092], device='cuda:0',\n", + " requires_grad=True),\n", + " Parameter containing:\n", + " tensor([[-0.0523, 0.0240, -0.0816, ..., -0.0907, 0.0489, 0.0188],\n", + " [-0.0491, 0.0112, 0.0373, ..., 0.0562, -0.0190, 0.0237],\n", + " [ 0.0174, -0.0335, 0.0115, ..., 0.0150, -0.0653, -0.0523],\n", + " ...,\n", + " [-0.0210, -0.0424, -0.0148, ..., -0.0392, -0.1321, -0.0233],\n", + " [-0.1408, -0.0953, 0.1749, ..., 0.1256, -0.1097, -0.0778],\n", + " [-0.1257, -0.1036, 0.0855, ..., 0.0856, -0.1131, -0.0770]],\n", + " device='cuda:0', requires_grad=True),\n", + " Parameter containing:\n", + " tensor([ 0.0278, -0.0064, 0.0545, ..., 0.0186, -0.0600, 0.0036],\n", + " device='cuda:0', requires_grad=True)]" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "list(lm.parameters())" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### krzywe uczenia" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[531.0113829517392,\n", + " 516.542261379181,\n", + " 510.99021596013944,\n", + " 508.91670819123317,\n", + " 509.691281725011,\n", + " 509.5967005513094,\n", + " 510.39878889794727,\n", + " 510.97901616528486,\n", + " 511.1877020218083,\n", + " 511.01325486962844,\n", + " 511.9617587007096,\n", + " 514.2758177779934,\n", + " 517.9274067739209,\n", + " 522.2558997769627,\n", + " 527.0581693919813]" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "history_ppl_valid" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [], + "source": [ + "import matplotlib.pyplot as plt" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[,\n", + " ]" + ] + }, + "execution_count": 39, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAApgElEQVR4nO3deXxU5dn/8c+VFcK+BERCSEBAARUkIIss4oaWqtXWHRBU3Fr3Vm37tP11eepT16KPCxVFLC4U16cVRVA2ZTEgi4hA2ANhXwRZk1y/P+agY0wgZGEmme/79ZrXnLnPdgWS75y5z5lzm7sjIiKxIS7SBYiIyPGj0BcRiSEKfRGRGKLQFxGJIQp9EZEYkhDpAo6mcePGnpGREekyRESqlLlz525199Si7VEf+hkZGWRnZ0e6DBGRKsXM1hTXru4dEZEYotAXEYkhCn0RkRii0BcRiSEKfRGRGKLQFxGJIQp9EZEYUm1D/5XZa5m2bEukyxARiSrVMvQP5hcydvYahr+czeyV2yJdjohI1Dhq6JtZCzP72MyWmNliM7szaH/YzL4ys4Vm9paZ1Q/aM8xsn5nNDx7Phm2ri5ktMrMcMxthZlYZP1RSQhxjhnUjrUEKw0Z/xvx1OytjNyIiVU5pjvTzgXvd/RSgO3C7mbUHPgQ6uvtpwDLgwbB1Vrh7p+BxS1j7M8BwoE3wGFARP0RxGtVO5p83nEmj2skMHjWbLzd8XVm7EhGpMo4a+u6e5+7zgundwBKgubtPdPf8YLFZQNqRtmNmzYC67j7TQ2M0jgEuLU/xR3NCvRqMvfFMaicnMGjUbHI276nM3YmIRL1j6tM3swygMzC7yKxhwISw15lm9rmZTTWz3kFbcyA3bJncoK24/Qw3s2wzy96ypXwnY1s0TOGfN56JmXHt87NYu21vubYnIlKVlTr0zaw28AZwl7t/Hdb+G0JdQGODpjwg3d07A/cAr5hZXaC4/vtiR2V395HunuXuWampP7gz6DFrlVqbsTeeyYH8Qq55fhZ5u/aVe5siIlVRqULfzBIJBf5Yd38zrH0IMBC4Nuiywd0PuPu2YHousAJoS+jIPrwLKA3YUBE/RGm0O6EOLw87k117D3HtP2azZfeB47VrEZGoUZqrdwwYBSxx98fC2gcA9wMXu/vesPZUM4sPplsROmG70t3zgN1m1j3Y5mDgnQr9aY7i1LR6vDi0K3m79jNo1Gx2fHPweO5eRCTiSnOk3wsYBPQPuwzzIuApoA7wYZFLM/sAC81sATAeuMXdtwfzbgWeB3IIfQIIPw9wXGRlNOT5IVms3PoNQ16cw9f7Dx3vEkREIsaCXpmolZWV5ZUxctZHX21i+Ji5dE6vz0vDupGSFPWDiImIlJqZzXX3rKLt1fIbuaXR/+Sm/P2qzsxds4PhY+ay/1BBpEsSEal0MRv6AD86rRkP//R0ZuRs5eevzONQQWGkSxIRqVQxHfoAl3dJ40+XdmTSks3c9fp8Cgqju7tLRKQ81JENDOrekv0HC/jLe0uomRjP3y4/jbi4SrktkIhIRCn0Azf1acXegwU8PmkZNRPj+eMlHaik+8GJiESMQj/MHeecxN5D+Tw3dSUpSfE8cOHJCn4RqVYU+mHMjAcGnMy+gwU8N20lKUkJ3Hlum0iXJSJSYRT6RZgZf/hxh2+7elKS4rmpT6tIlyUiUiEU+sWIizP+5/LT2H8odHK3RlI8g7q3jHRZIiLlptAvQXyc8fiVndh/qID/evsLUhLjubzLEYcMEBGJejF/nf6RJMbH8dQ1Z9C7TWN+OX4B/1mYF+mSRETKRaF/FDUS43luUBe6tGzAna99zuQlmyJdkohImSn0SyElKYEXru9K+xPrcuvYeTw9JUe3ZRaRKkmhX0p1aiQyZlg3urdqxN/eX0qPhybz4JsLWbpxd6RLExEptZi9tXJ5fLXxa176dDVvzlvPgfxCerZuxNBemfQ/uQnxun2DiESBkm6trNAvhx3fHOTVz9by8sw15O3aT3rDFAb3aMkVXVtQt0ZipMsTkRim0K9E+QWFfLB4Ey9+sorsNTuolRTPT7ukMaRnBq1Sa0e6PBGJQQr942RR7i5e/HQV/16Qx8GCQvq1S2Vor0x6n9RYd+4UkeNGoX+cbdl9gLGz1/DPWWvZuucArVNrcX3PDC47I41ayfpOnIhUrjIPl2hmLczsYzNbYmaLzezOoL2hmX1oZsuD5wZh6zxoZjlmttTMLghr72Jmi4J5I6wa38IytU4yd53blk8f6M/jV55OreQE/uudxXT/62T+/O8vWbd9b6RLFJEYdNQjfTNrBjRz93lmVgeYC1wKXA9sd/eHzOwBoIG7329m7YFXgW7AicAkoK27F5jZHOBOYBbwHjDC3Sccaf9V9Ui/KHdn3tqdvPjJKiZ8sRF359xTmnJ9rwx6tGqkWziLSIUq6Uj/qP0M7p4H5AXTu81sCdAcuAToFyz2EjAFuD9of83dDwCrzCwH6GZmq4G67j4zKGgMoTePI4Z+dWFmdGnZgC4tG5C3ax//nLWGV2avZeKXmzj5hDoM7ZXBJZ2aUyMxPtKlikg1dkxfzjKzDKAzMBtoGrwhHH5jaBIs1hxYF7ZabtDWPJgu2l7cfoabWbaZZW/ZsuVYSqwSmtWryS8vOJmZD57D/1x+KgD3v7GIHn+dzN/e/4q8XfsiXKGIVFelPqNoZrWBN4C73P3rI3RHFDfDj9D+w0b3kcBICHXvlLbGqqZGYjxXdk3niqwWzFq5nRc/WcWzU1fw3LSVDOh4AkN7ZtClZQN1/YhIhSlV6JtZIqHAH+vubwbNm8ysmbvnBf3+m4P2XKBF2OppwIagPa2Y9phnZvRo3YgerRuxbvteXp61htfmrOU/C/M4tXk9ru+ZwcDTm5GcoK4fESmf0ly9Y8AoYIm7PxY2611gSDA9BHgnrP0qM0s2s0ygDTAn6ALabWbdg20ODltHAi0apvDri05h1q/P4c+XdmTfoQLu/dcCej30EY99uIzNu/dHukQRqcJKc/XOWcB0YBFQGDT/mlC//jggHVgL/Mzdtwfr/AYYBuQT6g6aELRnAaOBmoRO4P7Cj1JAdbl6p6zcnRk5W3nxk9V89NVmEuONH53ajKG9Mjm9Rf1IlyciUUpfzqoGVm39hpc+Xc34ubnsOZBP5/T6DO2VyYUdTyAxXjdMFZHvKPSrkd37DzF+bi4vfbqa1dv20rRuMoO6t+Tqbuk0qp0c6fJEJAoo9KuhwkJnyrLNvPjJaqYv30pSQhyXnH4i1/fKoMOJ9SJdnohEUJm/nCXRKy7O6H9yU/qf3JTlm3bz0szVvDF3Pf+am0u3zIYM7ZnBee2bkqCuHxEJ6Ei/mtm19xDjstfx0szV5O7YR1qDmtzW7yR+2iWNpASFv0isUPdOjCkodD78chPPTF3BgnU7ObFeDW7p15orslroVg8iMUChH6PcnenLtzJi8nKy1+ygSZ1kbu7bmmu6pVMzSeEvUl0p9GOcuzNz5TZGTF7OrJXbaVw7ieF9WnHtmS11f3+RakihL9+as2o7T360nOnLt9IgJZEbe7dicI+W1NG4viLVhkJffmDumh08+dFypizdQr2aiQzrlcn1vTKoV1PhL1LVKfSlRAtzdzJicg6TlmyiTnIC1/fKYFivTBrUSop0aSJSRgp9OarFG3bx1Ec5TPhiI7WS4hnUI4ObemfqW74iVZBCX0pt6cbdPPVxDv9euIEaCfFc1z2dm/q0okmdGpEuTURKSaEvxyxn8x6e/jiHt+evJzE+jqu7pXNL39acUE/hLxLtFPpSZqu3fsPTU3J4c9564sy4omsat/U7iRPr14x0aSJSAoW+lNu67Xt5ZuoK/pW9jjgzbuydya39TqK2rvMXiToKfakwuTv28vAHS3ln/gYa107i7vPacmVWC93YTSSKlBT6+iuVY5bWIIW/X9WZt2/vRWbjWvzmrS+48O/T+XjpZqL9IEIk1in0pcw6tajPuJt78Ox1XThUUMjQFz9j8AtzWJL3daRLE5ESKPSlXMyMAR1PYOLdffndwPYszN3FRSOmc//4hWz+WoO4i0Sbo4a+mb1gZpvN7IuwttfNbH7wWG1m84P2DDPbFzbv2bB1upjZIjPLMbMRZmaV8hNJRCQlxDHsrEym/fJsbuiVyZuf59LvkSn8fdJy9h7Mj3R5IhIozZH+aGBAeIO7X+nundy9E/AG8GbY7BWH57n7LWHtzwDDgTbB43vblOqhXkoivx3Ynkn39KVfu1Qen7SMsx+ZwrjsdRQUqr9fJNKOGvruPg3YXty84Gj9CuDVI23DzJoBdd19pofO9I0BLj3maqXKaNmoFk9f24Xxt/SgWb2a/Gr8QgY+OYNPcrZGujSRmFbePv3ewCZ3Xx7Wlmlmn5vZVDPrHbQ1B3LDlskN2oplZsPNLNvMsrds2VLOEiWSsjIa8tZtPRlxdWe+3neIa5+fzbDRn5GzeXekSxOJSeUN/av5/lF+HpDu7p2Be4BXzKwuUFz/fYmf9d19pLtnuXtWampqOUuUSDMzLj79RCbf25cHLjyZz1Zt54InpvPbtxexdc+BSJcnElPKHPpmlgBcBrx+uM3dD7j7tmB6LrACaEvoyD4tbPU0YENZ9y1VU43EeG7p25opv+zHtWem8+qcdfR7eApPT8lh/6GCSJcnEhPKc6R/LvCVu3/bbWNmqWYWH0y3InTCdqW75wG7zax7cB5gMPBOOfYtVVij2sn88ZKOTLy7D91bNeJv7y/lnEen8vbn6ynUyV6RSlWaSzZfBWYC7cws18xuCGZdxQ9P4PYBFprZAmA8cIu7Hz4JfCvwPJBD6BPAhAqoX6qw1qm1eX5IFq/cdCb1UxK56/X5XPHcTL7coC93iVQW3XtHokJhoTN+bi4Pvf8VO/ceZHCPDO45vy11NW6vSJno3jsS1eLijCu6tuCje/tyzZnpvDRzNf0fmcpbn+fqfj4iFUihL1GlfkoSf770VN69/SyaN6jJ3a8v4MqRs1i6UZd4ilQEhb5EpVPT6vHWrT3562WnsmzTbi4aMZ0///tL9hzQLR1EykOhL1ErLs64uls6H93bjyuy0nh+xirOeXQK7y7YoC4fkTJS6EvUa1grib9edhpv3daT1DrJ3PHq51z7/Gx9q1ekDBT6UmV0Tm/AO7efxZ8u7cgX63cx4Inp/HXCEr5Rl49IqSn0pUqJjzMGdW/JR/f149LOzXlu6krOfWwq7y3KU5ePSCko9KVKalw7mUd+djrjb+lB/ZQkbhs7j8EvzGHllj2RLk0kqin0pUrLymjI//28F7//cXvmr93JgCem8/AHX7HvoO7lI1Ichb5UeQnxcQztlcnk+/ryo9Oa8b8fr+Dcx6byweKN6vIRKUKhL9VGkzo1ePzKTrw+vDu1kuO5+eW5DBv9GWu2fRPp0kSihkJfqp0zWzXiP3f05rc/OoU5q7Zz3uPTeGLSMt2+WQSFvlRTifFx3Ni7FZPv7cf57ZvyxKTlXPDENKYu00hsEtsU+lKtnVCvBk9dcwYv39CNODOGvDCH28bOZeOu/ZEuTSQiFPoSE3q3SeX9u3pz73ltmbxkM+c8OoXnp68kv6Aw0qWJHFcKfYkZyQnx/OKcNnx4d1+6Zjbkz/9ZwsAnZ5C9evvRVxapJhT6EnPSG6Xw4vVdefa6Luzad4ifPjuTX41fwPZvDka6NJFKp9CXmGRmDOh4ApPu6cvNfVrx5rz19H90Cq/NWatxeqVaU+hLTKuVnMCDF53Cf+7oTdsmdXjgzUVc/uynLN6wK9KliVSK0gyM/oKZbTazL8La/mBm681sfvC4KGzeg2aWY2ZLzeyCsPYuZrYomDfCzKzifxyRsml3Qh1ev7k7j/7sdNZu28uPn5zB//u/xezefyjSpYlUqNIc6Y8GBhTT/ri7dwoe7wGYWXvgKqBDsM7TZhYfLP8MMBxoEzyK26ZIxJgZl3dJY/K9fbm6WzqjP13NOY9O5f80aItUI0cNfXefBpT28oZLgNfc/YC7rwJygG5m1gyo6+4zPfTXMwa4tIw1i1Sq+ilJ/OUnp/LWbb1oUjeZX7z6OYNG6Q6eUj2Up0//52a2MOj+aRC0NQfWhS2TG7Q1D6aLthfLzIabWbaZZW/Zom9QSmR0alGfd24/iz9e0oEF60J38Hxs4lLdzkGqtLKG/jNAa6ATkAc8GrQX10/vR2gvlruPdPcsd89KTU0tY4ki5RcfZwzukfHtHTxHfJTDeY9P5eOvNke6NJEyKVPou/smdy9w90LgH0C3YFYu0CJs0TRgQ9CeVky7SJVw+A6er9x0JknxcQwd/RnDx2SzbvveSJcmckzKFPpBH/1hPwEOX9nzLnCVmSWbWSahE7Zz3D0P2G1m3YOrdgYD75SjbpGI6Nm6MRPu7MMvL2jH9OVbOfexqYyYvFxdPlJlJBxtATN7FegHNDazXOD3QD8z60Soi2Y1cDOAuy82s3HAl0A+cLu7H/5ruJXQlUA1gQnBQ6TKSUqI4/azT+InnZvzl/eW8NiHyxg/N5ffDWzPue2bRro8kSOyaL8ULSsry7OzsyNdhkiJPsnZyu/fXUzO5j30P7kJvxvYnozGtSJdlsQ4M5vr7llF2/WNXJFy6nVSYybc+d2gLec/Po1HPljK3oP5kS5N5AcU+iIV4PCgLR/dG7rK56mPczj30alMWJSnL3ZJVFHoi1SgJnVDV/mMu7kHdWsmcuvYeQwaNYeczbsjXZoIoNAXqRTdMhvy71+Evti1MDf0xa7/fm8Jew6oy0ciS6EvUkkS4uMY3CODj+7rx+VnpDFy2kr6PzKFd+avV5ePRIxCX6SSNa6dzP/89DTevr0XJ9SrwZ2vzefK52axJO/rSJcmMUihL3KcdGpRn7dv68VDl53K8s27GfjkDP7w7mJ27dPtm+X4UeiLHEdxccZV3dL5+L5+XNMtnTEzV9P/kSmMy16nEbvkuFDoi0RA/ZQk/nRpR979+VlkNK7Fr8Yv5LJnPmVh7s5IlybVnEJfJII6Nq/H+Ft68NgVp5O7Yx+X/O8nPPDGQrbsPhDp0qSaUuiLRJiZcdkZaXx0X19u6JXJ+Lm5nP3IFJ6buoID+bqRm1Qshb5IlKhbI5HfDmzPxLv7cGZmQ/464SvOf3waHyzeqEs8pcIo9EWiTKvU2oy6vitjhnUjOSGOm1+eyzX/mK1LPKVCKPRFolSftqm8d0dv/nRJB77a+DU/GjGdB99cxNY96u+XslPoi0SxhPg4BvXIYMp9Z3N9z0z+lb2Osx+ewshp6u+XslHoi1QB9VIS+d2P2/PB3X3omtmQ/34v1N8/Uf39cowU+iJVSOvU2rxwfVdeGtaNpPg4hr88l+tGqb9fSk+hL1IF9W2byoQ7e/PHSzqweEOov//Xby1im/r75SiOGvpm9oKZbTazL8LaHjazr8xsoZm9ZWb1g/YMM9tnZvODx7Nh63Qxs0VmlmNmI4IB0kWkjA7fxXPKff0Y0jODcZ+to9/DU/jHtJUczC+MdHkSpUpzpD8aGFCk7UOgo7ufBiwDHgybt8LdOwWPW8LanwGGA22CR9FtikgZ1E9J4vc/7sD7d/UhK6MBf3lvCec/PpUPv9yk/n75gaOGvrtPA7YXaZvo7odHg5gFpB1pG2bWDKjr7jM99Fs4Bri0TBWLSLFOalKbF4d2Y/TQriTEx3HTmGwGjZrDVxvV3y/fqYg+/WHAhLDXmWb2uZlNNbPeQVtzIDdsmdygTUQqWL92TZhwZ2/+38Ud+GLDLi76+3R+85au75eQhPKsbGa/AfKBsUFTHpDu7tvMrAvwtpl1AIrrvy/xc6eZDSfUFUR6enp5ShSJSYnxcQzpmcElnU7kiUnLeXnWGt6Zv4Fb+rbihrNaUTMpPtIlSoSU+UjfzIYAA4Frgy4b3P2Au28LpucCK4C2hI7sw7uA0oANJW3b3Ue6e5a7Z6Wmppa1RJGYVz8liT9c3IGJd/eh10mNeGTiMvo98jGvf7aWAt2/PyaVKfTNbABwP3Cxu+8Na081s/hguhWhE7Yr3T0P2G1m3YOrdgYD75S7ehEpldaptXluUBbjb+lB8/o1uf+NRVz492l89JVO9saa0lyy+SowE2hnZrlmdgPwFFAH+LDIpZl9gIVmtgAYD9zi7odPAt8KPA/kEPoEEH4eQESOg6yMhrxxa0+eve4MDhU4w0Znc80/Zmvwlhhi0f4un5WV5dnZ2ZEuQ6TaOVRQyKtz1vL3ScvZ9s1BLj79RH55QTtaNEyJdGlSAcxsrrtn/aBdoS8S23bvP8TIaSv5x/SVFBbCoB4t+UX/k6ifkhTp0qQcFPoickQbd+3niUnLGJe9jtrJCdx+9kkM6ZlBjURd6VMVlRT6uveOiABwQr0aPHT5abx/Vx+6ZoRG7jrn0am8OS+XQl3pU20o9EXke9o2rcOo67vyyk1n0rBWEveMW8DAJ2cwY/nWSJcmFUChLyLF6tm6Me/c3osRV3dm94FDXDdqNoNfmMOXG3Rbh6pMoS8iJYqLMy4+/UQm3dOX3/7oFBas28mPnpzOveMWsGHnvkiXJ2WgE7kiUmq79h7i6ak5vPjJagwY2iuTW/q20pU+UUhX74hIhVm/cx+PfrCUt+avp3ZyAjf1bsWwszKpnVyu23lJBVLoi0iF+2rj1zw2cRkTv9xEw1pJ3Nq3NYN6tNRlnlFAoS8ilWbBup08MnEp05dvpWndZH7evw1XZrUgKUGnDSNFoS8ilW72ym08MnEpn63eQYuGNbnznLb8pHNz4uM0Ourxpi9niUilO7NVI8bd3IPRQ7tSv2YS9/1rAec/PpX/LMzTF7yihEJfRCqUmdGvXRPe/Xkvnr3uDOLMuP2VeQx8coZu5RwFFPoiUinMjAEdm/H+XX14/MrT2XMgn2Gjs/npszP5dIW+3Rsp6tMXkePiUEEh/8rOZcTk5Wz8ej+9TmrEfee3o3N6g0iXVi3pRK6IRIX9hwoYO3stT3+cw7ZvDnLuKU2457x2tD+xbqRLq1YU+iISVb45kM+Ln6ziuWkr2b0/n4GnNePu89rSOrV2pEurFhT6IhKVdu09xD+mr+SFT1ax/1ABl5+Rxh3ntNEIXuWk0BeRqLZ1zwGembKCl2etobDQufyMNG47uzUtG9WKdGlVUpmv0zezF8xss5l9EdbW0Mw+NLPlwXODsHkPmlmOmS01swvC2ruY2aJg3ggz07c1RORbjWsn818D2zP1l/24rntL3pq/nv6PTuWecfNZsWVPpMurNkpzyeZoYECRtgeAye7eBpgcvMbM2gNXAR2CdZ42s8M34XgGGA60CR5FtykiQrN6NfnDxR2Y8auzGdozg/cW5XHeY1O549XPWbZpd6TLq/KOGvruPg3YXqT5EuClYPol4NKw9tfc/YC7rwJygG5m1gyo6+4zPdSfNCZsHRGRH2hStwa/HdieGff356Y+rZi0ZBMXPDGN28bO1UAu5VDWL2c1dfc8gOC5SdDeHFgXtlxu0NY8mC7aXiwzG25m2WaWvWXLljKWKCLVQePayTx44SnMuL8/t/c7ienLtnLRiOncNCabRbm7Il1elVPR38gtrp/ej9BeLHcf6e5Z7p6VmppaYcWJSNXVsFYS913Qjhn39+euc9swe+U2fvzUDIa+OId5a3dEurwqo6yhvynosiF43hy05wItwpZLAzYE7WnFtIuIHJN6KYncdW5bPnmgP7+8oB3z1+3ksqc/ZdCo2Xy2umhPtBRV1tB/FxgSTA8B3glrv8rMks0sk9AJ2zlBF9BuM+seXLUzOGwdEZFjVqdGIreffRIz7u/PgxeezJK8r/nZszO5amTo3j7Rfjl6pBz1On0zexXoBzQGNgG/B94GxgHpwFrgZ+6+PVj+N8AwIB+4y90nBO1ZhK4EqglMAH7hpfhf0XX6IlIa+w4W8MqctTw3dQWbdx+ga0YDftG/Db3bNCYWrxDXl7NEJCbsP1TAuOx1PDNlBXm79tOpRX3uOOckzm7XJKbCX6EvIjHlQH4Bb8xdz9NTcsjdsY+OzetyW7+TuKDDCTExkpdCX0Ri0qGCQt76fD1Pf5zD6m17SW+Ywo29M/lZlxbUTKq+A7gr9EUkphUUOh9+uZHnpq3k87U7aZCSyKDuLRncM4PGtZMjXV6FU+iLiADuztw1O3hu2komLdlEYnwcl5+Rxo29M6vVbZ1LCv2ESBQjIhIpZkZWRkOyMhqyYsseRs1Yxfi5ubz22VrOPaUpw/u0Iqtlg2p70ldH+iIS87buOcCYmWt4eeZqduw9ROf0+gzv3Yrzq/BJX3XviIgcxb6DBYyfu47nZ6xizba9tGyUwo1nZfLTKnjSV6EvIlJKBYXOxMWhk77z1wUnfXtkMLhHyypz0lehLyJyjNyd7DU7GFnkpO9NvTNpFeUnfXUiV0TkGJkZXTMa0jU46fv89FW8Me+7k74392lFlyp20ldH+iIix2DrngOM+XQ1Y2atYWdw0vf6nhlc2LEZSQkVfbf6slP3johIBdp7MJ835uYyasYqVm/bS+PayVzdrQXXnJlOs3o1I12eQl9EpDIUFjrTlm/h5Zlr+GjpZuLMOO+Upgzq0ZKerRtFrOtHffoiIpUgLs7o164J/do1Yd32vfxz9hrGfbaO9xdvpHVqLQZ1b8nlXdKoUyMx0qUCOtIXEalw+w8V8J+FeYyZtYYF63aSkhTPTzo3Z3CPDNqdUOe41KDuHRGRCFiYu5MxM9fw7oINHMwvpFtmQwb3aMkFHU4gMb7yTvwq9EVEImjHNwcZl72Of85ew7rt+2hSJ5mru6VzzZnpNK1bo8L3p9AXEYkCBYXO1GWbGTNzDVOXbSHOjAs6NGVQ9wy6t2pYYSd+dSJXRCQKxMcZ/U9uSv+Tm7Jm2zf8c9YaxmXn8t6ijbRtWptB3VvykzPSqJ1cOfFc5iN9M2sHvB7W1Ar4HVAfuAnYErT/2t3fC9Z5ELgBKADucPcPjrYfHemLSHW372AB/7dgA2NmreaL9V9TOzmBy85ozj3ntaV+SlKZtlmp3TtmFg+sB84EhgJ73P2RIsu0B14FugEnApOAtu5ecKRtK/RFJFa4O/PX7eTlmWuYuXIbH9/XjxqJZbu7Z2V375wDrHD3NUfoj7oEeM3dDwCrzCyH0BvAzAqqQUSkSjMzOqc3oHN6Aw7mF1bKbR0qaotXETqKP+znZrbQzF4wswZBW3NgXdgyuUHbD5jZcDPLNrPsLVu2FLeIiEi1Vln38Sn3Vs0sCbgY+FfQ9AzQGugE5AGPHl60mNWL7Vty95HunuXuWampqeUtUUREAhXxVnIhMM/dNwG4+yZ3L3D3QuAfhLpwIHRk3yJsvTRgQwXsX0RESqkiQv9qwrp2zKxZ2LyfAF8E0+8CV5lZspllAm2AORWwfxERKaVyncg1sxTgPODmsOa/mVknQl03qw/Pc/fFZjYO+BLIB24/2pU7IiJSscoV+u6+F2hUpG3QEZb/C/CX8uxTRETKLnqGeRERkUqn0BcRiSEKfRGRGKLQFxGJIQp9EZEYotAXEYkhCn0RkRii0BcRiSEKfRGRGKLQFxGJIQp9EZEYotAXEYkhCn0RkRii0BcRiSEKfRGRGKLQFxGJIQp9EZEYUq6Rs6LaqumAQ82GkNIw9JxYI9JViYhEVPUN/f/cA1uXfb8tMSV4E2jw/TeDYp8bhJ6T60GcPhCJVHuFhVBwMPQozP9uuuBQ8DgIhWHT4e3fvj5Y5HEI8g8UmX/gu+n8IssWFFn29jkQn1ihP2Z5B0ZfDewGCoB8d88ys4bA60AGoYHRr3D3HcHyDwI3BMvf4e4flGf/R3TFy/DNZti7HfZtD553fP/1xkWh5/07wQtL+CHjoEb9778hJNaE+GRICB7xScF0je+m45NCr783nRSsF7w+PH342eIhLj70bHHBdByYVdo/U0xwDz1CL4Jp/24epZhf0rLh+zjS/r/fcJT5xfjB74AdZf4RavFCvvdzFttWWGS+Fz/fC8ELoLAgNF1YEPa6PO35oeArPBRqOzxdkB+0HZ4f9hy+TkGw3rfT+d/fRnHhXlIGlJsFOZAcCvD4pNDz4WyIT/xuXmL9IC+SQs+VUFNFHOmf7e5bw14/AEx294fM7IHg9f1m1h64CugAnAhMMrO27l5QATX8UJOTgZNLt2xhIRzYVfwbQ9HnrzdA/v7gEbxr5wePSvpRwL57M/j2jSA+9Ank2+nwN4u476Ytjh/8gRf7DMX+sRf3XJb6S+1ItZUQPBVaq0Qli4O4xFAwxiWEHvGJQVtC6Dku4bvpw8slJH+3XFx8ELKHgzYI4LiEsPaw6dK2x4UHeNj2D7fFxUf6X+97KqN75xKgXzD9EjAFuD9of83dDwCrzCwH6AbMrIQajk1cXKg7p2aD8m2nsCB4A9gffHQL3gwKDoTeIPL3fzcd/mZxuC38SMcLQm9G3zuSOnwUVBg2HX6E5D/chhcCFhwJHuXZ4oJpSrFOKZXmKPa7hY+h1lIsU9wzhLUd3m9J863I/JKWDdtOuFLPO0ZH++RQ4jL2/U+O3/6ccRT/e1DS/CLLhh90xBV9Xdb2+LCwTlAXawUqb+g7MNHMHHjO3UcCTd09D8Dd88ysSbBsc2BW2Lq5QdsPmNlwYDhAenp6OUs8juLiISkl9BARiULlDf1e7r4hCPYPzeyrIyxb3KFNsYeBwZvHSICsrCx9RhcRqSDl+szk7huC583AW4S6azaZWTOA4HlzsHgu0CJs9TRgQ3n2LyIix6bMoW9mtcyszuFp4HzgC+BdYEiw2BDgnWD6XeAqM0s2s0ygDTCnrPsXEZFjV57unabAWxY6KZQAvOLu75vZZ8A4M7sBWAv8DMDdF5vZOOBLIB+4vdKu3BERkWKVOfTdfSVwejHt24BzSljnL8BfyrpPEREpH10HJSISQxT6IiIxRKEvIhJDzI/pG5PHn5ltAdaUcfXGwNajLhUdqlKtULXqrUq1QtWqtyrVClWr3vLW2tLdU4s2Rn3ol4eZZbt7VqTrKI2qVCtUrXqrUq1QteqtSrVC1aq3smpV946ISAxR6IuIxJDqHvojI13AMahKtULVqrcq1QpVq96qVCtUrXorpdZq3acvIiLfV92P9EVEJIxCX0QkhlTL0DezAWa21MxygiEbo5aZtTCzj81siZktNrM7I13T0ZhZvJl9bmb/jnQtR2Nm9c1svJl9Ffwb94h0TSUxs7uD34EvzOxVM6sR6ZrCmdkLZrbZzL4Ia2toZh+a2fLguZzDz1WcEup9OPhdWGhmb5lZ/QiW+K3iag2bd5+ZuZk1roh9VbvQN7N44H+BC4H2wNXB+LzRKh+4191PAboDt0d5vQB3AksiXUQp/R14391PJnSDwKis28yaA3cAWe7eEYgnNKZ0NBkNDCjSdnhM7DbA5OB1tBjND+v9EOjo7qcBy4AHj3dRJRjND2vFzFoA5xG6Y3GFqHahT2gglxx3X+nuB4HXCI3PG5XcPc/d5wXTuwmFUrHDSEYDM0sDfgQ8H+lajsbM6gJ9gFEA7n7Q3XdGtKgjSwBqmlkCkEKUDTLk7tOA7UWaLyE0FjbB86XHs6YjKa5ed5/o7vnBy1mEBnOKuBL+bQEeB35FCaMMlkV1DP3mwLqw1yWOxRttzCwD6AzMjnApR/IEoV/CwgjXURqtgC3Ai0F31PPBgD9Rx93XA48QOqLLA3a5+8TIVlUq3xsTG2hylOWjyTBgQqSLKImZXQysd/cFFbnd6hj6pR6LN5qYWW3gDeAud/860vUUx8wGApvdfW6kaymlBOAM4Bl37wx8Q3R1P3wr6Au/BMgETgRqmdl1ka2q+jKz3xDqWh0b6VqKY2YpwG+A31X0tqtj6Fe5sXjNLJFQ4I919zcjXc8R9AIuNrPVhLrN+pvZPyNb0hHlArnufviT03hCbwLR6FxglbtvcfdDwJtAzwjXVBoljYkdtcxsCDAQuNaj94tKrQkdACwI/t7SgHlmdkJ5N1wdQ/8zoI2ZZZpZEqGTYe9GuKYSWWi8yVHAEnd/LNL1HIm7P+juae6eQejf9SN3j9qjUXffCKwzs3ZB0zmEhuuMRmuB7maWEvxOnEOUnnQuoqQxsaOSmQ0A7gcudve9ka6nJO6+yN2buHtG8PeWC5wR/E6XS7UL/eAkzc+BDwj90Yxz98WRreqIegGDCB01zw8eF0W6qGrkF8BYM1sIdAL+O7LlFC/4NDIemAcsIvS3GVW3DDCzV4GZQDszyw3GwX4IOM/MlhO6yuShSNYYroR6nwLqAB8Gf2vPRrTIQAm1Vs6+ovfTjYiIVLRqd6QvIiIlU+iLiMQQhb6ISAxR6IuIxBCFvohIDFHoi4jEEIW+iEgM+f8Fe/INo1reYwAAAABJRU5ErkJggg==\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.plot(np.arange(len(history_ppl_train)), history_ppl_train, history_ppl_valid)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Inferencja" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "'Gości innych nie widział oprócz spółleśników'" + ] + }, + "execution_count": 40, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "'Gości innych nie widział oprócz spółleśników'" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "tokenized = list(tokenize('Gości innych nie widział oprócz spółleśników',lowercase = True))" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": {}, + "outputs": [], + "source": [ + "#tokenized = tokenized[-NGRAMS :-1 ]" + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['gości', 'innych', 'nie', 'widział', 'oprócz', 'spółleśników']" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenized" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [], + "source": [ + "ids = []\n", + "for word in tokenized:\n", + " if word in vocab_stoi:\n", + " ids.append(vocab_stoi[word])\n", + " else:\n", + " ids.append(vocab_stoi[''])" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[2671, 3168, 5873, 13240, 6938, 15001]" + ] + }, + "execution_count": 45, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ids" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "LSTM(\n", + " (emb): Embedding(15005, 100)\n", + " (rec): LSTM(100, 256, batch_first=True)\n", + " (fc1): Linear(in_features=256, out_features=15005, bias=True)\n", + ")" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "lm.eval()" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": {}, + "outputs": [], + "source": [ + "ids = torch.tensor(ids, dtype = torch.long, device = device)" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "tensor([ 2671, 3168, 5873, 13240, 6938, 15001], device='cuda:0')" + ] + }, + "execution_count": 48, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ids" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": {}, + "outputs": [], + "source": [ + "preds= lm(ids.unsqueeze(0))" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "15001" + ] + }, + "execution_count": 50, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.argmax(torch.softmax(preds,1),1).item()" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "torch.return_types.max(\n", + "values=tensor([0.1040], device='cuda:0', grad_fn=),\n", + "indices=tensor([15001], device='cuda:0'))" + ] + }, + "execution_count": 51, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "torch.max(torch.softmax(preds,1),1)" + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "''" + ] + }, + "execution_count": 52, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "vocab_itos[torch.argmax(torch.softmax(preds,1),1).item()]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## ZADANIE: GENEROWANIE TEKSTU\n", + "\n", + "Napisać funkcję generującą tekst, która dla podanego fragmentu generuje tekst.\n", + "Generowanie tekstu ma wyglądać następująco: Z 10 najbardziej prawodpodobnych tokenów należy wylosować jeden, ala ma to byc token inny niż specjalny (UNK, BOS, EOS, PAD). \n", + "\n", + "Wygenerować tekst o długości 30 tokenówm" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### generowanie tekstu" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [], + "source": [ + "tokenized = list(tokenize('Pan Tadeusz', lowercase = True))" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "['pan', 'tadeusz']" + ] + }, + "execution_count": 60, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "tokenized" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [], + "source": [ + "ids = []\n", + "for word in tokenized:\n", + " if word in vocab_stoi:\n", + " ids.append(vocab_stoi[word])\n", + " else:\n", + " ids.append(vocab_stoi[''])\n", + "ids = torch.tensor([ids], dtype = torch.long, device = device)" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "nie\n", + "ma\n", + "z\n", + "nim\n", + "w\n", + "tym\n", + "w\n", + "nim\n", + "w\n", + "w\n", + "nie\n", + "a\n", + "i\n", + "z\n", + "tak\n", + "z\n", + "w\n", + "ręku\n", + "w\n", + "z\n", + "już\n", + "na\n", + "a\n", + "to\n", + "i\n", + "tak\n", + "nie\n", + "w\n", + "z\n", + "nim\n" + ] + } + ], + "source": [ + "candidates_number = 10\n", + "for i in range(30):\n", + " preds= lm(ids)\n", + " candidates = torch.topk(torch.softmax(preds,1),candidates_number)[1][0].cpu().numpy()\n", + " candidate = 15001\n", + " while candidate > 15000:\n", + " candidate = candidates[np.random.randint(candidates_number)]\n", + " print(vocab_itos[candidate])\n", + " ids = torch.cat((ids, torch.tensor([[candidate]], device = device)), 1)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +}