{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## importy" ] },
{ "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/media/kuba/ssdsam/anaconda3/lib/python3.8/site-packages/gensim/similarities/__init__.py:15: UserWarning: The gensim.similarities.levenshtein submodule is disabled, because the optional Levenshtein package is unavailable. Install Levenhstein (e.g. `pip install python-Levenshtein`) to suppress this warning.\n", " warnings.warn(msg)\n" ] } ], "source": [ "from pathlib import Path\n", "\n", "from gensim.utils import tokenize\n", "import numpy as np\n", "import torch\n", "from tqdm.notebook import tqdm" ] },
{ "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "SEED = 42\n", "# seed torch's RNGs so weight initialisation and any other torch sampling are reproducible\n", "torch.manual_seed(SEED)\n", "\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "# to force CPU execution: device = torch.device(\"cpu\")" ] },
{ "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Using cuda device\n" ] } ], "source": [ "print(f'Using {device} device')" ] },
{ "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "device(type='cuda')" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "device" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## przygotowanie zbiorów" ] },
{ "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "# NOTE: machine-specific absolute path; point DATA_DIR at the folder holding the Pan Tadeusz splits\n", "DATA_DIR = Path('/home/kuba/Syncthing/przedmioty/2020-02/ISI/zajecia9_ngramowy_model_jDDezykowy')\n", "pan_tadeusz_path_train = DATA_DIR / 'pan-tadeusz-train.txt'" ] },
{ "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "pan_tadeusz_path_valid = DATA_DIR / 'pan-tadeusz-test.txt'" ] },
{ "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# explicit UTF-8: the corpus is Polish text and the platform default encoding may differ\n", "with open(pan_tadeusz_path_train, encoding='utf-8') as f:\n", "    corpora_train = f.read()" 
] },
{ "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [], "source": [ "corpora_train_tokenized = list(tokenize(corpora_train, lowercase=True))" ] },
{ "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "vocab_all = sorted(set(corpora_train_tokenized))" ] },
{ "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "16598" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(vocab_all)" ] },
{ "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "VOCAB_SIZE = 15001  # regular words kept; every other word maps to <unk>\n", "UNK, BOS, EOS, PAD = '<unk>', '<bos>', '<eos>', '<pad>'\n", "\n", "# NOTE(review): this keeps the alphabetically-first VOCAB_SIZE words, not the most frequent\n", "# ones; collections.Counter(corpora_train_tokenized).most_common(VOCAB_SIZE) is the usual\n", "# choice (it would change every token id shown below).\n", "vocab_itos = vocab_all[:VOCAB_SIZE] + [UNK, BOS, EOS, PAD]" ] },
{ "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "15005" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(vocab_itos)" ] },
{ "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "vocab_stoi = {token: idx for idx, token in enumerate(vocab_itos)}" ] },
{ "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [], "source": [ "NGRAMS = 5  # window length: NGRAMS-1 context tokens + 1 target token" ] },
{ "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "def get_token_id(dataset):\n", "    \"\"\"Convert a token list into vocabulary ids framed for n-gram sampling.\n", "\n", "    The result is NGRAMS-1 <pad> ids, one <bos> id, the id of every token in\n", "    `dataset` (out-of-vocabulary tokens become <unk>) and a closing <eos> id.\n", "    \"\"\"\n", "    unk_id = vocab_stoi[UNK]\n", "    token_ids = [vocab_stoi[PAD]] * (NGRAMS - 1) + [vocab_stoi[BOS]]\n", "    token_ids.extend(vocab_stoi.get(token, unk_id) for token in dataset)\n", "    token_ids.append(vocab_stoi[EOS])\n", "    return token_ids" ] },
{ "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [], "source": [ "train_ids = get_token_id(corpora_train_tokenized)" ] },
{ "cell_type": "code", "execution_count": 17, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "[15004,\n", " 
15004,\n", " 15004,\n", " 15004,\n", " 15002,\n", " 7,\n", " 5002,\n", " 7247,\n", " 11955,\n", " 1432,\n", " 7018,\n", " 14739,\n", " 5506,\n", " 4696,\n", " 4276,\n", " 7505,\n", " 2642,\n", " 8477,\n", " 7259,\n", " 10870,\n", " 10530,\n", " 7506,\n", " 12968,\n", " 7997,\n", " 1911,\n", " 12479,\n", " 11129,\n", " 13069,\n", " 11797,\n", " 5819]" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_ids[:30]" ] },
{ "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "def get_samples(dataset):\n", "    \"\"\"Return every window of NGRAMS consecutive ids in `dataset` (stride 1).\n", "\n", "    A sequence of length L yields L - NGRAMS + 1 windows, so the last window is\n", "    kept (for get_token_id output its target is the closing <eos>); a sequence\n", "    shorter than NGRAMS yields an empty list.\n", "    \"\"\"\n", "    return [dataset[i:i + NGRAMS] for i in range(len(dataset) - NGRAMS + 1)]" ] },
{ "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "train_ids = get_samples(train_ids)" ] },
{ "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [], "source": [ "train_ids = torch.tensor(train_ids, dtype=torch.long, device=device)" ] },
{ "cell_type": "code", "execution_count": 21, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "tensor([[15004, 15004, 15004, 15004, 15002],\n", " [15004, 15004, 15004, 15002, 7],\n", " [15004, 15004, 15002, 7, 5002],\n", " [15004, 15002, 7, 5002, 7247],\n", " [15002, 7, 5002, 7247, 11955],\n", " [ 7, 5002, 7247, 11955, 1432],\n", " [ 5002, 7247, 11955, 1432, 7018],\n", " [ 7247, 11955, 1432, 7018, 14739],\n", " [11955, 1432, 7018, 14739, 5506],\n", " [ 1432, 7018, 14739, 5506, 4696],\n", " [ 7018, 14739, 5506, 4696, 4276],\n", " [14739, 5506, 4696, 4276, 7505],\n", " [ 5506, 4696, 4276, 7505, 2642],\n", " [ 4696, 4276, 7505, 2642, 8477],\n", " [ 4276, 7505, 2642, 8477, 7259],\n", " [ 7505, 2642, 8477, 7259, 10870],\n", " [ 2642, 8477, 7259, 10870, 10530],\n", " [ 8477, 7259, 10870, 10530, 7506],\n", " [ 7259, 10870, 10530, 7506, 12968],\n", " [10870, 10530, 7506, 12968, 7997],\n", " [10530, 7506, 12968, 7997, 1911],\n", " [ 7506, 12968, 7997, 1911, 12479],\n", " 
[12968, 7997, 1911, 12479, 11129],\n", " [ 7997, 1911, 12479, 11129, 13069],\n", " [ 1911, 12479, 11129, 13069, 11797],\n", " [12479, 11129, 13069, 11797, 5819],\n", " [11129, 13069, 11797, 5819, 6268],\n", " [13069, 11797, 5819, 6268, 2807],\n", " [11797, 5819, 6268, 2807, 7831],\n", " [ 5819, 6268, 2807, 7831, 12893]], device='cuda:0')" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_ids[:30]" ] },
{ "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "torch.Size([57022, 5])" ] }, "execution_count": 22, "metadata": {}, "output_type": "execute_result" } ], "source": [ "train_ids.shape" ] },
{ "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "with open(pan_tadeusz_path_valid, encoding='utf-8') as f:\n", "    corpora_valid = f.read()" ] },
{ "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [], "source": [ "corpora_valid_tokenized = list(tokenize(corpora_valid, lowercase=True))" ] },
{ "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "valid_token_ids = get_token_id(corpora_valid_tokenized)" ] },
{ "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "valid_ids = torch.tensor(get_samples(valid_token_ids), dtype=torch.long, device=device)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## model" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "Reference: [torch.nn.LSTM](https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html)" ] },
{ "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "class LSTM(torch.nn.Module):\n", "    \"\"\"Next-token model: embed the context, run it through a one-layer LSTM and map the\n", "    final hidden state to unnormalised scores (logits) over the whole vocabulary.\"\"\"\n", "\n", "    def __init__(self, emb_dim=100, hidden_dim=256):\n", "        super().__init__()\n", "        self.emb = torch.nn.Embedding(len(vocab_itos), emb_dim)\n", "        self.rec = torch.nn.LSTM(emb_dim, hidden_dim, 1, batch_first=True)\n", "        self.fc1 = torch.nn.Linear(hidden_dim, len(vocab_itos))\n", "\n", "    def forward(self, x):\n", "        # x: token ids laid out (batch, seq_len) because the LSTM is batch_first\n", "        emb = self.emb(x)\n", "        _, (h_n, _) = self.rec(emb)\n", "        # h_n is (num_layers, batch, hidden_dim); use the top layer's final state\n", "        return self.fc1(h_n[-1])" ] },
{ "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "lm = LSTM().to(device)" ] },
{ "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "criterion = torch.nn.CrossEntropyLoss()" ] },
{ "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "LEARNING_RATE = 1e-4\n", "optimizer = torch.optim.Adam(lm.parameters(), lr=LEARNING_RATE)" ] },
{ "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "BATCH_SIZE = 128\n", "EPOCHS = 15" ] },
{ "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "def get_ppl(dataset_ids):\n", "    \"\"\"Perplexity of `lm` on a tensor of n-gram windows.\n", "\n", "    Each row of `dataset_ids` holds NGRAMS ids: the first NGRAMS-1 are the context fed to\n", "    the model and the last one is the target. Returns exp(mean cross-entropy) over all rows,\n", "    including a final batch shorter than BATCH_SIZE. Raises ValueError for an empty tensor.\n", "    \"\"\"\n", "    if len(dataset_ids) == 0:\n", "        raise ValueError('get_ppl needs at least one n-gram window')\n", "\n", "    was_training = lm.training\n", "    lm.eval()\n", "    nll_sum = 0.0\n", "    with torch.no_grad():  # evaluation only: skip building the autograd graph\n", "        for i in range(0, len(dataset_ids), BATCH_SIZE):\n", "            X = dataset_ids[i:i + BATCH_SIZE, :NGRAMS - 1]\n", "            Y = dataset_ids[i:i + BATCH_SIZE, NGRAMS - 1]\n", "            predictions = lm(X)\n", "            # sanity check: uniform predictions (torch.zeros_like(predictions)) give ppl == len(vocab_itos)\n", "            # criterion returns the batch mean; scale by batch size so a short last batch is weighted correctly\n", "            nll_sum += criterion(predictions, Y).item() * len(Y)\n", "    lm.train(was_training)  # restore the caller's train/eval mode\n", "    return np.exp(nll_sum / len(dataset_ids))" ] },
{ "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2c8985983fcd4ec5b93ba32b0345d000", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 0\n", "train ppl: 2296.6914856482526\n", "valid ppl: 528.9542436139727\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": 
"d2acde1a6e19416cb1eaee8f40f0114f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 1\n", "train ppl: 2093.302103954666\n", "valid ppl: 514.4726844027333\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "c9df3b66aede47febbe6852ca62eed17", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 2\n", "train ppl: 2014.09679023559\n", "valid ppl: 510.12146471773366\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d5272c8b25f74c5a9383a829ed30308e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 3\n", "train ppl: 1939.0594855086504\n", "valid ppl: 509.1060151440451\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d6656f2e796b44b7b03d204ecd89f89b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 4\n", "train ppl: 1854.4566511885196\n", "valid ppl: 510.02244291272973\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "86419de761c1443c86c70726099b9838", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": 
"stdout", "output_type": "stream", "text": [ "\n", "epoch: 5\n", "train ppl: 1755.030202547313\n", "valid ppl: 508.494174178397\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5d12b6fcfc0045c9b8bac58343036dff", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 6\n", "train ppl: 1646.180912657662\n", "valid ppl: 506.06383737670035\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d6593e6387d5439a9c1301a366fecda6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 7\n", "train ppl: 1533.0501876139222\n", "valid ppl: 504.08067276707567\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "32742824eb00458b8586795a7fb56d84", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 8\n", "train ppl: 1420.680717507558\n", "valid ppl: 502.6906095632547\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5ee78d1ebd03415885d4ea2663c0ee7f", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 9\n", "train ppl: 1311.1083504083306\n", "valid ppl: 503.5230045363773\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": 
"b30185826c4341b088934df92839852c", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 10\n", "train ppl: 1203.498635587493\n", "valid ppl: 505.7599916969862\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "7eeeeb3ef93a45f3a6709a8579164c79", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 11\n", "train ppl: 1100.0681613054269\n", "valid ppl: 507.6071195979723\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "919fd04ac9ed476e931149c15dc460db", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 12\n", "train ppl: 1003.217414775517\n", "valid ppl: 510.07952767103245\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "a30dc2c9bd194ead94be7de3275a1e5e", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 13\n", "train ppl: 912.2987798296267\n", "valid ppl: 512.8275727599236\n", "\n" ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "94f2bac40c83416fad59742f2e2adff4", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))" ] }, "metadata": {}, "output_type": "display_data" }, { 
"name": "stdout", "output_type": "stream", "text": [ "\n", "epoch: 14\n", "train ppl: 826.911431868259\n", "valid ppl: 516.1525759633064\n", "\n" ] } ], "source": [ "history_ppl_train = []\n", "history_ppl_valid = []\n", "best_ppl_valid = float('inf')\n", "best_state = None\n", "\n", "for epoch in range(EPOCHS):\n", "    lm.train()\n", "    # Visit the n-grams in a new random order every epoch: neighbouring windows share\n", "    # NGRAMS-1 tokens, so unshuffled batches are strongly correlated.\n", "    shuffled_ids = train_ids[torch.randperm(len(train_ids), device=train_ids.device)]\n", "    for i in tqdm(range(0, len(shuffled_ids) - BATCH_SIZE + 1, BATCH_SIZE)):\n", "        X = shuffled_ids[i:i + BATCH_SIZE, :NGRAMS - 1]\n", "        Y = shuffled_ids[i:i + BATCH_SIZE, NGRAMS - 1]\n", "        loss = criterion(lm(X), Y)\n", "\n", "        optimizer.zero_grad()\n", "        loss.backward()\n", "        optimizer.step()\n", "\n", "    ppl_train = get_ppl(train_ids)\n", "    ppl_valid = get_ppl(valid_ids)\n", "    history_ppl_train.append(ppl_train)\n", "    history_ppl_valid.append(ppl_valid)\n", "\n", "    # In the recorded run validation perplexity bottoms out around epoch 8 and then rises,\n", "    # so keep a copy of the weights with the lowest validation perplexity seen so far.\n", "    if ppl_valid < best_ppl_valid:\n", "        best_ppl_valid = ppl_valid\n", "        best_state = {k: v.detach().clone() for k, v in lm.state_dict().items()}\n", "\n", "    print('epoch: ', epoch)\n", "    print('train ppl: ', ppl_train)\n", "    print('valid ppl: ', ppl_valid)\n", "    print()\n", "\n", "# Continue with the best-on-validation weights rather than the last epoch's.\n", "if best_state is not None:\n", "    lm.load_state_dict(best_state)" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## parametry modelu" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# per-tensor shapes and the total size instead of dumping every weight value\n", "for name, param in lm.named_parameters():\n", "    print(f'{name:20s} {tuple(param.shape)}')\n", "print('trainable parameters:', sum(p.numel() for p in lm.parameters() if p.requires_grad))" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "### krzywe uczenia" ] },
{ "cell_type": "code", "execution_count": 36, "metadata": { "scrolled": true }, "outputs": [ { "data": { "text/plain": [ "[528.9542436139727,\n", " 
514.4726844027333,\n", " 510.12146471773366,\n", " 509.1060151440451,\n", " 510.02244291272973,\n", " 508.494174178397,\n", " 506.06383737670035,\n", " 504.08067276707567,\n", " 502.6906095632547,\n", " 503.5230045363773,\n", " 505.7599916969862,\n", " 507.6071195979723,\n", " 510.07952767103245,\n", " 512.8275727599236,\n", " 516.1525759633064]" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "history_ppl_valid" ] },
{ "cell_type": "code", "execution_count": 37, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "epochs = np.arange(len(history_ppl_train))\n", "fig, ax = plt.subplots()\n", "ax.plot(epochs, history_ppl_train, label='train')\n", "ax.plot(epochs, history_ppl_valid, label='valid')\n", "ax.set(title='Learning curves', xlabel='epoch', ylabel='perplexity')\n", "ax.legend();" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## Inferencja" ] },
{ "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'Gości innych nie widział oprócz spółleśników'" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "prompt = 'Gości innych nie widział oprócz spółleśników'\n", "prompt" ] },
{ "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [], "source": [ "tokenized = list(tokenize(prompt, lowercase=True))" ] },
{ "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['gości', 'innych', 'nie', 'widział', 'oprócz', 'spółleśników']" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "tokenized" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Frame the prompt the way get_token_id frames training text (<pad>s, then <bos>) and keep\n", "# only the last NGRAMS-1 tokens: the model was trained on contexts of exactly that length.\n", "context = (['<pad>'] * (NGRAMS - 1) + ['<bos>'] + tokenized)[-(NGRAMS - 1):]\n", "ids = [vocab_stoi.get(word, vocab_stoi['<unk>']) for word in context]" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ids" ] },
{ "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "LSTM(\n", " (emb): Embedding(15005, 100)\n", " (rec): LSTM(100, 256, batch_first=True)\n", " (fc1): Linear(in_features=256, out_features=15005, bias=True)\n", ")" ] }, "execution_count": 45, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lm.eval()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "context_ids = torch.tensor(ids, dtype=torch.long, device=device)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "context_ids" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with torch.no_grad():  # inference only: no autograd graph needed\n", "    logits = lm(context_ids.unsqueeze(0))  # add a batch dimension of size 1\n", "probs = torch.softmax(logits, dim=1)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "torch.argmax(probs, dim=1).item()" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "torch.max(probs, dim=1)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "vocab_itos[torch.argmax(probs, dim=1).item()]" ] },
{ "cell_type": "markdown", "metadata": {}, "source": [ "## ZADANIE: GENEROWANIE TEKSTU\n", "\n", "Napisać funkcję, która dla podanego fragmentu tekstu generuje jego kontynuację.\n", "Generowanie tekstu ma wyglądać następująco: z 10 najbardziej prawdopodobnych tokenów należy wylosować jeden, ale ma to być token inny niż specjalny (UNK, BOS, EOS, PAD).\n", "\n", "Wygenerować tekst o długości 30 tokenów." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.5" } }, "nbformat": 4, "nbformat_minor": 4 }