ISI-lstm-lm/lstm - ODPOWIEDZI.ipynb
2021-05-31 15:05:47 +02:00

1466 lines
47 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## importy"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/media/kuba/ssdsam/anaconda3/lib/python3.8/site-packages/gensim/similarities/__init__.py:15: UserWarning: The gensim.similarities.levenshtein submodule is disabled, because the optional Levenshtein package <https://pypi.org/project/python-Levenshtein/> is unavailable. Install Levenhstein (e.g. `pip install python-Levenshtein`) to suppress this warning.\n",
" warnings.warn(msg)\n"
]
}
],
"source": [
"from gensim.utils import tokenize\n",
"import numpy as np\n",
"import torch\n",
"from tqdm.notebook import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"#device = 'cpu'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using cuda device\n"
]
}
],
"source": [
"print('Using {} device'.format(device))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cuda')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"device"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## przygotowanie zbiorów"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"pan_tadeusz_path_train= '/home/kuba/Syncthing/przedmioty/2020-02/ISI/zajecia9_ngramowy_model_jDDezykowy/pan-tadeusz-train.txt'"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"pan_tadeusz_path_valid= '/home/kuba/Syncthing/przedmioty/2020-02/ISI/zajecia9_ngramowy_model_jDDezykowy/pan-tadeusz-test.txt'"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"corpora_train = open(pan_tadeusz_path_train).read()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"corpora_train_tokenized = list(tokenize(corpora_train,lowercase = True))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"vocab_itos = sorted(set(corpora_train_tokenized))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"16598"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(vocab_itos)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"vocab_itos = vocab_itos[:15005]\n",
"vocab_itos[15001] = \"<UNK>\"\n",
"vocab_itos[15002] = \"<BOS>\"\n",
"vocab_itos[15003] = \"<EOS>\"\n",
"vocab_itos[15004] = \"<PAD>\""
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"15005"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(vocab_itos)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"vocab_stoi = dict()\n",
"for i, token in enumerate(vocab_itos):\n",
" vocab_stoi[token] = i"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"NGRAMS = 5"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"def get_token_id(dataset):\n",
" token_ids = [vocab_stoi['<PAD>']] * (NGRAMS-1) + [vocab_stoi['<BOS>']]\n",
" for token in dataset:\n",
" try:\n",
" token_ids.append(vocab_stoi[token])\n",
" except KeyError:\n",
" token_ids.append(vocab_stoi['<UNK>'])\n",
" token_ids.append(vocab_stoi['<EOS>'])\n",
" return token_ids"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"train_ids = get_token_id(corpora_train_tokenized)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"[15004,\n",
" 15004,\n",
" 15004,\n",
" 15004,\n",
" 15002,\n",
" 7,\n",
" 5002,\n",
" 7247,\n",
" 11955,\n",
" 1432,\n",
" 7018,\n",
" 14739,\n",
" 5506,\n",
" 4696,\n",
" 4276,\n",
" 7505,\n",
" 2642,\n",
" 8477,\n",
" 7259,\n",
" 10870,\n",
" 10530,\n",
" 7506,\n",
" 12968,\n",
" 7997,\n",
" 1911,\n",
" 12479,\n",
" 11129,\n",
" 13069,\n",
" 11797,\n",
" 5819]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_ids[:30]"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"def get_samples(dataset):\n",
" samples = []\n",
" for i in range(len(dataset)-NGRAMS):\n",
" samples.append(dataset[i:i+NGRAMS])\n",
" return samples"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"train_ids = get_samples(train_ids)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"train_ids = torch.tensor(train_ids, device = device)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[15004, 15004, 15004, 15004, 15002],\n",
" [15004, 15004, 15004, 15002, 7],\n",
" [15004, 15004, 15002, 7, 5002],\n",
" [15004, 15002, 7, 5002, 7247],\n",
" [15002, 7, 5002, 7247, 11955],\n",
" [ 7, 5002, 7247, 11955, 1432],\n",
" [ 5002, 7247, 11955, 1432, 7018],\n",
" [ 7247, 11955, 1432, 7018, 14739],\n",
" [11955, 1432, 7018, 14739, 5506],\n",
" [ 1432, 7018, 14739, 5506, 4696],\n",
" [ 7018, 14739, 5506, 4696, 4276],\n",
" [14739, 5506, 4696, 4276, 7505],\n",
" [ 5506, 4696, 4276, 7505, 2642],\n",
" [ 4696, 4276, 7505, 2642, 8477],\n",
" [ 4276, 7505, 2642, 8477, 7259],\n",
" [ 7505, 2642, 8477, 7259, 10870],\n",
" [ 2642, 8477, 7259, 10870, 10530],\n",
" [ 8477, 7259, 10870, 10530, 7506],\n",
" [ 7259, 10870, 10530, 7506, 12968],\n",
" [10870, 10530, 7506, 12968, 7997],\n",
" [10530, 7506, 12968, 7997, 1911],\n",
" [ 7506, 12968, 7997, 1911, 12479],\n",
" [12968, 7997, 1911, 12479, 11129],\n",
" [ 7997, 1911, 12479, 11129, 13069],\n",
" [ 1911, 12479, 11129, 13069, 11797],\n",
" [12479, 11129, 13069, 11797, 5819],\n",
" [11129, 13069, 11797, 5819, 6268],\n",
" [13069, 11797, 5819, 6268, 2807],\n",
" [11797, 5819, 6268, 2807, 7831],\n",
" [ 5819, 6268, 2807, 7831, 12893]], device='cuda:0')"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_ids[:30]"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([57022, 5])"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_ids.shape"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"corpora_valid = open(pan_tadeusz_path_valid).read()"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"corpora_valid_tokenized = list(tokenize(corpora_valid,lowercase = True))"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"valid_ids = get_token_id(corpora_valid_tokenized)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"valid_ids = torch.tensor(get_samples(valid_ids), dtype = torch.long, device = device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## model"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"# https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"class LSTM(torch.nn.Module):\n",
"\n",
" def __init__(self):\n",
" super(LSTM, self).__init__()\n",
" self.emb = torch.nn.Embedding(len(vocab_itos),100)\n",
" self.rec = torch.nn.LSTM(100, 256, 1, batch_first = True)\n",
" self.fc1 = torch.nn.Linear( 256 ,len(vocab_itos))\n",
" #self.dropout = torch.nn.Dropout(0.5)\n",
"\n",
" def forward(self, x):\n",
" emb = self.emb(x)\n",
" #emb = self.dropout(emb)\n",
" output, (h_n, c_n) = self.rec(emb)\n",
" hidden = h_n.squeeze(0)\n",
" out = self.fc1(hidden)\n",
" #out = self.dropout(out)\n",
" return out"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"lm = LSTM().to(device)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"criterion = torch.nn.CrossEntropyLoss()"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"optimizer = torch.optim.Adam(lm.parameters(),lr=0.0001)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"BATCH_SIZE = 128\n",
"EPOCHS = 15"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"def get_ppl(dataset_ids):\n",
" lm.eval()\n",
"\n",
" batches = 0\n",
" loss_sum =0\n",
" acc_score = 0\n",
"\n",
" for i in range(0, len(dataset_ids)-BATCH_SIZE+1, BATCH_SIZE):\n",
" X = dataset_ids[i:i+BATCH_SIZE,:NGRAMS-1]\n",
" Y = dataset_ids[i:i+BATCH_SIZE,NGRAMS-1]\n",
" predictions = lm(X)\n",
" \n",
" # equally distributted\n",
" # predictions = torch.zeros_like(predictions)\n",
" \n",
" loss = criterion(predictions,Y)\n",
"\n",
" loss_sum += loss.item()\n",
" batches += 1\n",
"\n",
" return np.exp(loss_sum / batches)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2c8985983fcd4ec5b93ba32b0345d000",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 0\n",
"train ppl: 2296.6914856482526\n",
"valid ppl: 528.9542436139727\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d2acde1a6e19416cb1eaee8f40f0114f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 1\n",
"train ppl: 2093.302103954666\n",
"valid ppl: 514.4726844027333\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c9df3b66aede47febbe6852ca62eed17",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 2\n",
"train ppl: 2014.09679023559\n",
"valid ppl: 510.12146471773366\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d5272c8b25f74c5a9383a829ed30308e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 3\n",
"train ppl: 1939.0594855086504\n",
"valid ppl: 509.1060151440451\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d6656f2e796b44b7b03d204ecd89f89b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 4\n",
"train ppl: 1854.4566511885196\n",
"valid ppl: 510.02244291272973\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "86419de761c1443c86c70726099b9838",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 5\n",
"train ppl: 1755.030202547313\n",
"valid ppl: 508.494174178397\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5d12b6fcfc0045c9b8bac58343036dff",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 6\n",
"train ppl: 1646.180912657662\n",
"valid ppl: 506.06383737670035\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d6593e6387d5439a9c1301a366fecda6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 7\n",
"train ppl: 1533.0501876139222\n",
"valid ppl: 504.08067276707567\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "32742824eb00458b8586795a7fb56d84",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 8\n",
"train ppl: 1420.680717507558\n",
"valid ppl: 502.6906095632547\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5ee78d1ebd03415885d4ea2663c0ee7f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 9\n",
"train ppl: 1311.1083504083306\n",
"valid ppl: 503.5230045363773\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b30185826c4341b088934df92839852c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 10\n",
"train ppl: 1203.498635587493\n",
"valid ppl: 505.7599916969862\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7eeeeb3ef93a45f3a6709a8579164c79",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 11\n",
"train ppl: 1100.0681613054269\n",
"valid ppl: 507.6071195979723\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "919fd04ac9ed476e931149c15dc460db",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 12\n",
"train ppl: 1003.217414775517\n",
"valid ppl: 510.07952767103245\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a30dc2c9bd194ead94be7de3275a1e5e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 13\n",
"train ppl: 912.2987798296267\n",
"valid ppl: 512.8275727599236\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "94f2bac40c83416fad59742f2e2adff4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 14\n",
"train ppl: 826.911431868259\n",
"valid ppl: 516.1525759633064\n",
"\n"
]
}
],
"source": [
"history_ppl_train = []\n",
"history_ppl_valid = []\n",
"for epoch in range(EPOCHS):\n",
" \n",
" batches = 0\n",
" loss_sum =0\n",
" acc_score = 0\n",
" lm.train()\n",
" #for i in range(0, len(train_ids)-BATCH_SIZE+1, BATCH_SIZE):\n",
" for i in tqdm(range(0, len(train_ids)-BATCH_SIZE+1, BATCH_SIZE)):\n",
" X = train_ids[i:i+BATCH_SIZE,:NGRAMS-1]\n",
" Y = train_ids[i:i+BATCH_SIZE,NGRAMS-1]\n",
" predictions = lm(X)\n",
" loss = criterion(predictions,Y)\n",
" \n",
" \n",
" \n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" loss_sum += loss.item()\n",
" batches += 1\n",
" \n",
" ppl_train = get_ppl(train_ids)\n",
" ppl_valid = get_ppl(valid_ids)\n",
" \n",
" history_ppl_train.append(ppl_train)\n",
" history_ppl_valid.append(ppl_valid)\n",
" \n",
" print('epoch: ', epoch)\n",
" print('train ppl: ', ppl_train)\n",
" print('valid ppl: ', ppl_valid)\n",
" print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## parametry modelu"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/plain": [
"[Parameter containing:\n",
" tensor([[ 0.4949, -1.2472, -0.7167, ..., -0.0801, -2.1905, 0.9790],\n",
" [ 0.2070, 1.7394, 0.8255, ..., -0.5796, -0.3776, 0.8831],\n",
" [-1.4559, 1.1073, -0.4904, ..., -1.2919, -2.2661, -0.5476],\n",
" ...,\n",
" [-0.3706, 0.2133, 0.0484, ..., -0.5792, -0.5769, -0.6941],\n",
" [ 0.5502, -0.1212, -2.0879, ..., 0.6764, -0.5961, -0.6282],\n",
" [ 0.8362, 0.2193, -0.0807, ..., 2.7741, -0.2589, 0.3310]],\n",
" device='cuda:0', requires_grad=True),\n",
" Parameter containing:\n",
" tensor([[ 0.0252, -0.0744, 0.0817, ..., -0.0559, 0.0896, 0.0208],\n",
" [ 0.0423, 0.0329, -0.0610, ..., -0.0009, 0.0169, -0.0361],\n",
" [ 0.0507, -0.0838, 0.0520, ..., 0.0395, 0.0067, 0.0173],\n",
" ...,\n",
" [ 0.0669, 0.0430, -0.0306, ..., 0.0096, 0.0619, -0.0992],\n",
" [-0.0153, -0.0888, 0.0580, ..., -0.0433, 0.0399, -0.0494],\n",
" [-0.0067, -0.0053, -0.0242, ..., 0.0017, -0.0306, -0.0972]],\n",
" device='cuda:0', requires_grad=True),\n",
" Parameter containing:\n",
" tensor([[ 0.0212, -0.0425, -0.0329, ..., -0.0206, 0.0839, 0.0286],\n",
" [ 0.0952, 0.0298, -0.1211, ..., -0.0468, -0.0233, -0.0620],\n",
" [ 0.0108, 0.0422, -0.0492, ..., -0.0288, -0.0231, 0.0078],\n",
" ...,\n",
" [ 0.0253, 0.0154, -0.0765, ..., -0.0025, 0.0057, -0.0408],\n",
" [ 0.0892, -0.0928, -0.1039, ..., -0.1531, -0.0011, 0.0180],\n",
" [ 0.1341, 0.0666, -0.0548, ..., -0.0573, -0.0376, -0.0813]],\n",
" device='cuda:0', requires_grad=True),\n",
" Parameter containing:\n",
" tensor([0.1115, 0.0423, 0.1438, ..., 0.0361, 0.0297, 0.0956], device='cuda:0',\n",
" requires_grad=True),\n",
" Parameter containing:\n",
" tensor([0.0786, 0.1201, 0.0857, ..., 0.1177, 0.1319, 0.0886], device='cuda:0',\n",
" requires_grad=True),\n",
" Parameter containing:\n",
" tensor([[-0.0158, 0.0236, -0.0958, ..., -0.0906, -0.0678, 0.0057],\n",
" [-0.0871, -0.0788, 0.1217, ..., -0.0231, -0.0102, 0.0220],\n",
" [ 0.0265, -0.0680, -0.0219, ..., -0.0520, -0.0565, 0.0628],\n",
" ...,\n",
" [-0.0618, 0.0232, 0.0898, ..., 0.1069, -0.0112, 0.0103],\n",
" [-0.0489, 0.0708, 0.0546, ..., 0.1186, -0.0987, 0.1411],\n",
" [-0.0764, 0.0463, 0.0947, ..., 0.1104, -0.0312, 0.1118]],\n",
" device='cuda:0', requires_grad=True),\n",
" Parameter containing:\n",
" tensor([ 0.0299, -0.0551, -0.0323, ..., -0.0371, -0.0297, -0.0157],\n",
" device='cuda:0', requires_grad=True)]"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(lm.parameters())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### krzywe uczenia"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"[528.9542436139727,\n",
" 514.4726844027333,\n",
" 510.12146471773366,\n",
" 509.1060151440451,\n",
" 510.02244291272973,\n",
" 508.494174178397,\n",
" 506.06383737670035,\n",
" 504.08067276707567,\n",
" 502.6906095632547,\n",
" 503.5230045363773,\n",
" 505.7599916969862,\n",
" 507.6071195979723,\n",
" 510.07952767103245,\n",
" 512.8275727599236,\n",
" 516.1525759633064]"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"history_ppl_valid"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f02842a99a0>,\n",
" <matplotlib.lines.Line2D at 0x7f02842a9a90>]"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAX0AAAD4CAYAAAAAczaOAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAAoWElEQVR4nO3dd5xU9b3/8ddnG71L3QUWBFFEaRtAOhoVe4kkVlARNPZoojH+bpLfvdd784smKho1qDQjGGLDGgvSVNoiRUSBpS/g0pRetnx+f8wBx2UXlm0zO/N+Ph7zmDOfOWfms7D7Pme+58w55u6IiEh8SIh0AyIiUnkU+iIicUShLyISRxT6IiJxRKEvIhJHkiLdwPGcdNJJnp6eHuk2RESqlAULFmxz98aF61Ef+unp6WRmZka6DRGRKsXM1hVV1/COiEgcUeiLiMQRhb6ISBxR6IuIxBGFvohIHFHoi4jEEYW+iEgcidnQnzh3PTNXbI10GyIiUSUmQ/9QXgEvz13HiAmZzF61PdLtiIhEjZgM/ZSkBCbc3IPWjWoyfPx8MtfuiHRLIiJR4bihb2YtzWyamX1tZl+Z2T1B/VEz+8bMlpjZG2ZWP6inm9l+M1sU3J4Le63uZvalmWWZ2Sgzs4r6wRrVrsY/bulJs7rVuXHsfBZt+L6i3kpEpMooyZZ+HnC/u58G9ALuMLOOwEdAJ3c/E1gBPBS2zCp37xLcbgurPwuMBNoHt8Hl8UMUp0md6kwc0YuGtVIY+uJclm7cWZFvJyIS9Y4b+u6+2d2/CKZ3A18Dqe7+obvnBbPNAdKO9Tpm1hyo6+6zPXRh3gnA5WVpviSa1avOxBE9qVM9mRtenMvyb3dX9FuKiEStExrTN7N0oCswt9BTNwPvhz1uY2YLzWyGmfULaqlAdtg82UGtqPcZaWaZZpa5dWvZj8BJa1CTiSN6kpKUwHUvzCFry54yv6aISFVU4tA3s9rAa8C97r4rrP4woSGgl4PSZqCVu3cF7gMmmlldoKjxey/qvdx9tLtnuHtG48ZHnQ66VFo3qsXEEb0A49rn57B2295yeV0RkaqkRKFvZsmEAv9ld389rD4MuBi4Lhiywd0Puvv2YHoBsAo4hdCWffgQUBqwqTx+iJI6uXFtJo7oSV6Bc90Lc8n+bl9lvr2ISMSV5OgdA14Evnb3v4bVBwMPApe6+76wemMzSwym2xLaYbva3TcDu82sV/CaQ4Ep5frTlMApTevw0vAe7D6QyzXPz2Hzzv2V3YKISMSUZEu/D3ADcHbYYZgXAk8DdYCPCh2a2R9YYmaLgVeB29z98IHyvwReALIIfQII3w9QaU5vUY+Xhvfk+725XPv8XLbsOhCJNkREKp0FozJRKyMjwyvqcokL1u3ghhfnkVq/Bq+M7EWj2tUq5H1ERCqbmS1w94zC9Zj8Rm5JdW/dkDE3/oQN3+3j+hfn8f2+Q5FuSUSkQsV16AP0atuI54dmsGrrHm54cR67DuRGuiURkQoT96EP0K99Y567vhvffLuLYWPmsedg3vEXEhGpghT6gbNPbcpT13RjSfZObh47n32HFPwiEnsU+mEGd2rGE7/oQua6HYyYkMmB3PxItyQiUq4U+oVc0rkFjw3pzOertnPbPxZwME/BLyKxQ6FfhCu7pfE/V5zB9OVbuXPiQnLzCyLdkohIuVDoF+OaHq34z8tO56NlOdz7yiLyFPwiEgOSIt1ANBt6VjqH8gr473e/JiUpgceGdCYxocKu+yIiUuEU+sdxS7+2HMwr4NEPlpOSmMD/XnkGCQp+EamiFPolcMegdhzMzWfUJ1kkJxn/dVknKvBKjyIiFUahX0K/OvcUDuYX8PcZq8lc+x3X9mzF5V1TqVs9OdKtiYiUWFyfcO1EuTv/WpDNhNlrWbpxF9WTE7jkzBZc27MVXVrW19a/iESN4k64ptAvpSXZ3zNx7nreWryJfYfyOa153dDWf5cW1NHWv4hEmE
K/guw+kMuURZuYOHc9yzbvokZyIpd2Dm39n5lWT1v/IhIRCv0K5u4szt7JpGDrf39uPqe3CG39X9YlldrVtPtERCqPQr8S7TqQy5SFG3l57nq++XY3NVMSuaxLC67t0Zoz0upFuj0RiQMK/QhwdxZtCI39v71kEwdyCzgjtR7X9mzFpZ1bUEtb/yJSQRT6EbZzfy5TFm1kYrD1Xyslkcu6pnJtj1Z0StXWv4iUr1KHvpm1BCYAzYACYLS7P2lmDYF/AunAWuDn7v5dsMxDwHAgH7jb3T8I6t2BcUAN4D3gHj9OA7ES+oe5O1+sD239v7NkEwfzCuicVo9rerTi4s4tNPYvIuWiLKHfHGju7l+YWR1gAXA5cCOww93/ZGa/BRq4+4Nm1hGYBPQAWgAfA6e4e76ZzQPuAeYQCv1R7v7+sd4/1kI/3M59ubyxMJuJ89azImcPNZITufCM5vw8I40ebRrqyB8RKbXiQv+4m5XuvhnYHEzvNrOvgVTgMmBgMNt4YDrwYFB/xd0PAmvMLAvoYWZrgbruPjtoaAKhlccxQz+W1auZzI192jCsdzpfrP+eVxds4O3Fm3nti2xaN6rJVd3S+Fn3NFrUrxHpVkUkRpzQWIKZpQNdgblA02CFgLtvNrMmwWyphLbkD8sOarnBdOF6Ue8zEhgJ0KpVqxNpsUoyM7q3bkD31g34/cWn8/7SzfwrM5u/fLSCv368gr7tTmJIRkvO69iU6smJkW5XRKqwEoe+mdUGXgPudfddxxh6KOoJP0b96KL7aGA0hIZ3StpjLKiRksiV3dK4slsaG3bs418LsnltQTZ3T1pI3epJXNYllSEZaZyRqi9+iciJK1Hom1kyocB/2d1fD8o5ZtY82MpvDmwJ6tlAy7DF04BNQT2tiLoUo2XDmtx37ince057Zq/ezuTMDUzO3MBLc9ZxarM6XNU9jSu6ptKodrVItyoiVURJduQaoTH7He5+b1j9UWB72I7chu7+gJmdDkzkhx25U4H2wY7c+cBdhIaH3gOecvf3jvX+sbwjtzR27s/lnSWbmJyZzeIN35OUYJxzWhOGdG/JwA6NSUrUxdBEpGxH7/QFZgFfEjpkE+B3hIJ7MtAKWA8McfcdwTIPAzcDeYSGg94P6hn8cMjm+8Bd8XbIZnlakbObVxdk8/oX2Wzbc4iTalfjZ91Cwz/tmtSJdHsiEkH6clYMy80vYPryrUzO3MC0b7aQV+B0aVmfIRlpXNK5hc75LxKHFPpxYtueg7y5cCOTMzewImcPNVMSubJbKjf2TtfWv0gcUejHGXdnSfZO/jFnHVMWb+JQXgH92p/EsLPSGXRqE13gXSTGKfTj2PY9B3ll/gZemr2Ob3cdoFXDmgw9qzVDMlpSr4aGfkRikUJfyM0v4MOvchj3+Rrmr/2OmimJ/KxbGsN6t9bQj0iMUejLjyzduJNxn6/lrbChnxt7pzOoQxMSNPQjUuUp9KVIhYd+WjeqydCz0hmSkaajfkSqMIW+HFNufgEffPUt4z5bS+a68KGfdNo1qR3p9kTkBCn0pcSODP0s2sSh/NDQz0190hl4ioZ+RKoKhb6csG17DvLKvPW8NGcdObsOauhHpApR6Eup5eYX8O+l3zL+8x+Gfq7p0YpbB7SlSZ3qkW5PRIqg0Jdy8WX2TsZ+toYpizeRnGhc37M1tw44mcZ1dKZPkWii0JdytXbbXkZ9spI3F26kWlIiN5zVmlv7t9VpnkWihEJfKsTqrXt46pMspiwKhf/Q3q25tf/JNKyVEunWROKaQl8qVNaWPTz1yUreWryJGsmJDOudzsh+bWmg8BeJCIW+VIqsLbt5cmoW7yzZRM3kRG7sk86Ifm2pX1PhL1KZFPpSqVbk7ObJqSt5d8lmaldL4qY+6dzSty31aupQT5HKoNCXiFj+7W6enLqC9778ljrVkripbxuG922js3uKVDCFvkTU15t38eTHK/n3V99Sp3oSw/u24ea+bfQlL5EKUl
zoH/cq2mY2xsy2mNnSsNo/zWxRcFtrZouCerqZ7Q977rmwZbqb2ZdmlmVmo4ILrkucOK15XZ67oTvv3t2Xs9o24omPV9L3T58waupKdh/IjXR7InGjJBdG7w/sASa4e6cinv8LsNPd/9PM0oF3iplvHnAPMAd4Dxh1+ILpx6It/di0dONOnvh4JR9/nUO9GsmM6NeGYb3TqaMtf5FyUeotfXefCewo5kUN+Dkw6Thv3hyo6+6zPbSWmQBcXoK+JUZ1Sq3HC8MyePvOvmS0bsBjH66g35+n8fzM1RzMy490eyIx67ihfxz9gBx3XxlWa2NmC81shpn1C2qpQHbYPNlBrUhmNtLMMs0sc+vWrWVsUaLZGWn1ePHGnzDljj6cmVafR977mp/+dQZvLd5EtO9vEqmKyhr61/DjrfzNQCt37wrcB0w0s7pAUeP3xf5Fu/tod89w94zGjRuXsUWpCjq3rM+Em3vw0vAe1K6WzN2TFnL53z5j7urtkW5NJKaUOvTNLAm4Evjn4Zq7H3T37cH0AmAVcAqhLfu0sMXTgE2lfW+JXf3aN+adu/ry2JDO5Ow6yC9Gz2HEhExWbd0T6dZEYkJZtvR/Cnzj7keGbcyssZklBtNtgfbAanffDOw2s17BfoChwJQyvLfEsMQE46ruaUz79UB+c34HZq/aznmPz+T/vPkl2/YcjHR7IlVaSQ7ZnATMBjqYWbaZDQ+eupqjd+D2B5aY2WLgVeA2dz+8E/iXwAtAFqFPAMc9ckfiW42URO4Y1I7pvxnIdT1bMWneBgb8eRpPf7KS/Ye0s1ekNPTlLKkyVm3dw/97/xs+XJZDs7rVue+8U/hZtzQSdQlHkaOU+pBNkWhxcuPajB6aweRbz6Jpveo88OoSLho1i5krdISXSEkp9KXK6dGmIW/e3punrunK3kN5DB0zjxtenMvXm3dFujWRqKfQlyrJzLikcws+vm8A/+ei01iSvZMLR83i1/9azOad+yPdnkjU0pi+xISd+3L52/Qsxn22loQEuKVvW24d0FandZC4pTF9iWn1aibzuwtPY+r9AzivYzOenpbFwEen89LsteTmF0S6PZGoodCXmNKyYU1GXdOVKXf0oV2T2vzHlK84/4mZfPJNjk7rIIJCX2JU55b1eWVkL54fmgEON4/LZNjY+azM2R3p1kQiSqEvMcvMOLdjU/59b3/+4+KOLFz/HYOfnMUfpizl+32HIt2eSEQo9CXmpSQlMLxvG2b8ZhDX9GjJS3PWMeDR6Yz7bI3G+yXuKPQlbjSslcJ/X34G793Tj06pdfnj28u44MlZTF++JdKtiVQahb7EnVOb1eUfw3vy/NAM8vILuHHsfG4aO09n8pS4oNCXuHR4vP+DX/XndxeeSuba7zj/8Zn859vL2LlP1+yV2KXQl7hWLSmRkf1PZtpvBjIkoyVjP1/DwMem8dLsteRpvF9ikEJfBDipdjX+98ozePeufnRoVof/mPIVF436lE9Xbot0ayLlSqEvEqZji7pMGtGL567vxr7cPK5/cS63jM9kzba9kW5NpFwo9EUKMTMGd2rOR78awIODT2X2qm2c9/gMHnl3GbsOaLxfqjaFvkgxqicn8suBofH+K7qm8sKnaxj06HQmzl1PfoFO6SBVk0Jf5Dia1KnOn6/qzNt39qVt41r87o0vuWjULD5fpfF+qXoU+iIl1Cm1HpNvPYu/XduN3QfyuPb5udz6Uibrtmu8X6qOklwYfYyZbTGzpWG1P5rZRjNbFNwuDHvuITPLMrPlZnZ+WL27mX0ZPDfKzHRhU6lyzIyLzmzO1PsH8JvzOzBr5TZ++tcZ/M97X2u8X6qEkmzpjwMGF1F/3N27BLf3AMysI3A1cHqwzDNmlhjM/ywwEmgf3Ip6TZEqoXpyIncMasf0X4fG+5+ftZpBj07n5bnrdHy/RLXjhr67zwR2lPD1LgNecfeD7r4GyAJ6mFlzoK67z/bQSc0nAJeXsmeRqNGk7g/j/S
c3qc3Dbyzl4qd0fL9Er7KM6d9pZkuC4Z8GQS0V2BA2T3ZQSw2mC9eLZGYjzSzTzDK3bt1ahhZFKken1Hr8c2Qvnr2uG3sPHT6+fz6rdT4fiTKlDf1ngZOBLsBm4C9Bvahxej9GvUjuPtrdM9w9o3HjxqVsUaRymRkXnPHD8f1zVu/gvMdn8l/v6Hw+Ej1KFfrunuPu+e5eADwP9AieygZahs2aBmwK6mlF1EVizpHj+389kCEZaYz5LHQ+nwk6n49EgVKFfjBGf9gVwOEje94CrjazambWhtAO23nuvhnYbWa9gqN2hgJTytC3SNRrXKca/3vlmbxzV19ObVaX30/5iguenMWMFRqylMgpySGbk4DZQAczyzaz4cCfg8MvlwCDgF8BuPtXwGRgGfBv4A53zw9e6pfAC4R27q4C3i/vH0YkGp3eoh4TR/Rk9A3dOZRfwLAx87hp7Dyytmi8XyqfhQ6miV4ZGRmemZkZ6TZEysXBvHwmfL6OUVNXsi83nxt6teaec9rToFZKpFuTGGNmC9w9o3Bd38gVqUTVkhIZ0b8t038zkKt/0pIJs9cy8LHpjNX1eqWSKPRFIqBR7Wo8ckXoer1npNbj/769jPOfmMkn3+QQ7Z++pWpT6ItE0KnN6vLS8B68OCwDHG4el8nQMfNYtmlXpFuTGKXQF4kwM+Oc05ry73v78x8Xd2RJ9k4uemoWD7y6mJxdByLdnsQY7cgViTI79+Xy9LSVjP98HYkJxoj+bbm1f1tqVUuKdGtShWhHrkgVUa9mMg9f1JGP7xvAOac1YdTUlQx8bDqT5uniLVJ2Cn2RKNWqUU2evrYbr9/em1YNa/LQ619y4ZOzmL58i3b2Sqkp9EWiXLdWDXj1trN49rpuHMjL58ax87WzV0pNoS9SBYSfzO33F3fky42hnb2/+ddivt2pnb1SctqRK1IFaWevHI925IrEkMM7e6feP4CfdmzKqKkrGfBoaGevzuQpx6LQF6nCWjasyVPXdOWN23uT3ijY2TtqFtO0s1eKodAXiQFdWzXgX7edxXPXd+NQXgE3jZ3PDS9qZ68cTaEvEiPMjMGdmvPhrwbwh0s6snSTdvbK0bQjVyRG7dyfyzPTshj72VoSEuCWvm25dUBb6lRPjnRrUgm0I1ckztSrkcxDF57G1PsHcG7HZjw9LYsBj05n/OdrOZSnnb3xSqEvEuMO7+x9684+dGhahz+89RXnPj6Dd5ds1s7eOKTQF4kTZ6bVZ+KInoy96SdUT0rkjolfcPkznzN39fZItyaVqCTXyB1jZlvMbGlY7VEz+8bMlpjZG2ZWP6inm9l+M1sU3J4LW6Z7cF3dLDMbFVwgXUQqkZkxqEMT3runH49edSY5Ow/wi9FzuGX8fFbm7I50e1IJSrKlPw4YXKj2EdDJ3c8EVgAPhT23yt27BLfbwurPAiOB9sGt8GuKSCVJTDCGZLRk+m8G8sDgDsxdvYPzn5jJb19bonP4x7jjhr67zwR2FKp96O55wcM5QNqxXsPMmgN13X22hwYRJwCXl6pjESk31ZMTuX1gO2Y8MIgbe7fhtS+yGfDoNB77YDm7D+RGuj2pAOUxpn8z8H7Y4zZmttDMZphZv6CWCmSHzZMd1EQkCjSslcLvL+nI1PsGcl7YkT7jPlujI31iTJlC38weBvKAl4PSZqCVu3cF7gMmmlldoKjx+2IPGzCzkWaWaWaZW7duLUuLInICWjWqyaiwI33++PYyzn18Bu8s2aQjfWJEqUPfzIYBFwPXBUM2uPtBd98eTC8AVgGnENqyDx8CSgM2Fffa7j7a3TPcPaNx48albVFESqnwkT53TlzI5c98zhwd6VPllSr0zWww8CBwqbvvC6s3NrPEYLotoR22q919M7DbzHoFR+0MBaaUuXsRqTBFHelz9eg5DB83nxU60qfKKskhm5OA2UAHM8s2s+HA00Ad4KNCh2b2B5aY2WLgVeA2dz+8E/iXwAtAFqFPAOH7AUQkSh
U+0mfemh0MfmImD766ROf0qYJ07h0ROSE79h7i6U+yeGnOWhITjGG907mt/8k0qJUS6dYkTHHn3lHoi0iprN++j8c/XsGbizZSOyWJkf3bclPfNtTW1buigkJfRCrE8m9385cPl/Phshwa1Urh9kHtuK5nK6onJ0a6tbim0BeRCrVw/Xc89uFyPsvaTot61bnnp+35Wbc0khJ1iq9I0KmVRaRCdW3VgJdv6cXLt/SkSd3qPPjal5z7+EzeXryJgoLo3riMJwp9ESlXfdqdxBu39+b5oRmkJCZw16SFXPTUp3zyTY6+4BUFFPoiUu7MjHM7NuW9e/rxxC+6sPdgHjePy2TIc7N1KucIU+iLSIVJTDAu75rK1PsH8MgVndjw3T5+MXoOQ8fM48vsnZFuLy5pR66IVJoDuflMmL2WZ6av4vt9uVzQqRn3n3cK7ZrUiXRrMUdH74hI1Nh9IJcXZq3hhVmr2Z+bz5Xd0rjnnPa0bFgz0q3FDIW+iESdHXsP8ez0LMbPXoe7c22PVtxxdjua1Kke6daqPIW+iEStzTv3M2pqFpMzN5CSmMCw3umM6NeGRrWrRbq1KkuhLyJRb+22vTz+8QreWryJ6kmJXN+rFSP6t9WWfyko9EWkysjasodnpmXx5qKNJCcmcG3PVtza/2Sa1VP4l5RCX0SqnLXb9vK3aVm8sXAjCWb84ictuW3gyaTWrxHp1qKeQl9EqqwNO/bxzPRVvLpgAwBXdU/j9oHtdLTPMSj0RaTK2/j9fv4+YxWvzNtAvjtXdk3ljkHtSD+pVqRbizoKfRGJGTm7DvD3Gat5ee46cvMLuKxLKPzbNakd6daihkJfRGLOlt0HeGHWGl6avY4DeflcdEZz7jq7PR2a6Ru+Cn0RiVnb9xzkxU/XMP7ztew9lM8FnZpx59ntOL1FvUi3FjGlPp++mY0xsy1mtjSs1tDMPjKzlcF9g7DnHjKzLDNbbmbnh9W7m9mXwXOjzMzK4wcTEWlUuxoPDD6Vz357Nnef055Ps7Zx0ahPuWV8Jkuyv490e1GlJGfZHAcMLlT7LTDV3dsDU4PHmFlH4Grg9GCZZ8zs8DXTngVGAu2DW+HXFBEpk/o1U7jv3FP49MGzue/cU5i/dgeXPv0ZN46dx4J130W6vahw3NB395nAjkLly4DxwfR44PKw+ivuftDd1wBZQA8zaw7UdffZHhpPmhC2jIhIuapXIzm0xf/gIB4Y3IEl2Tv52bOfc+3zc5ixYmtcX8yltOfTb+rumwGC+yZBPRXYEDZfdlBLDaYL14tkZiPNLNPMMrdu3VrKFkUk3tWpnsztA9vx6YODePjC01i9dS/Dxszjgidn8ebCjeTmF0S6xUpX3hdRKWqc3o9RL5K7j3b3DHfPaNy4cbk1JyLxqWZKEiP6t2XmA4N4bEhn8guce/+5iIGPTufFT9ew92BepFusNKUN/ZxgyIbgfktQzwZahs2XBmwK6mlF1EVEKk1KUgJXdU/jg3v7M+bGDFIb1OC/3llG7z99wmMfLGfbnoORbrHClTb03wKGBdPDgClh9avNrJqZtSG0w3ZeMAS028x6BUftDA1bRkSkUiUkGGef2pTJt57F67f35qy2jfjb9Cx6/+kTfvfGl6zZtjfSLVaY4x6nb2aTgIHASUAO8AfgTWAy0ApYDwxx9x3B/A8DNwN5wL3u/n5QzyB0JFAN4H3gLi/B3hQdpy8ilWH11j08P2sNr32RTW5+AYNPb8atA06mS8v6kW6tVPTlLBGREtiy+wDjP1/LS7PXsetAHj3bNOS2ASczsENjqtLXixT6IiInYM/BPF6Zt54XP13D5p0H6NC0DiP7t+WSzi1ISSrvY2DKn0JfRKQUcvMLeHvxJv4+YzXLc3bTvF51hvdtw9U9WlG7WlKk2yuWQl9EpAzcnekrtvL3GauYs3oHdaoncUOv1tzYJz0qL+eo0BcRKSeLNnzP6JmreH/ptyQnJHBJ5xbc1CedTqnRc4I3hb
6ISDlbs20vYz5dw6sLstmfm0+P9Ibc1Cedczs2JSkxsuP+Cn0RkQqyc38uk+dvYPzstWR/t5/U+jUYelZrfvGTltSvmRKRnhT6IiIVLL/A+WhZDuM+X8Oc1TuokZzIFd1Sual3Ou2bVu6FXRT6IiKVaNmmXYz7fA1vLtrEobwC+rU/iRt7pzOoQxMSEir+eH+FvohIBGzfc5BJ89bz0px15Ow6SHqjmgzrnc5V3dOoUz25wt5XoS8iEkG5+QW8v/Rbxn62hoXrv6d2tSSGZKQx7Kx00k+qVe7vp9AXEYkSizZ8z9jP1vDuks3ku3N2hybc1KcNfdo1KrdTPSj0RUSiTM6uA7w8Zx0vz13P9r2HOKVpbW7s3YYruqZSIyXx+C9wDAp9EZEodSA3n7cXb2LsZ2tZtnkX9Wokc3WPltw+sB31apRu3L+40I/+swaJiMS46smJDMloybt392XyrWfR++RGvJqZTUoFfMEres8WJCISZ8yMHm0a0qNNQ/YezCvzEE9RtKUvIhKFalXQGTwV+iIicUShLyISRxT6IiJxpNShb2YdzGxR2G2Xmd1rZn80s41h9QvDlnnIzLLMbLmZnV8+P4KIiJRUqfcUuPtyoAuAmSUCG4E3gJuAx939sfD5zawjcDVwOtAC+NjMTnH3/NL2ICIiJ6a8hnfOAVa5+7pjzHMZ8Iq7H3T3NUAW0KOc3l9EREqgvEL/amBS2OM7zWyJmY0xswZBLRXYEDZPdlA7ipmNNLNMM8vcunVrObUoIiJlDn0zSwEuBf4VlJ4FTiY09LMZ+MvhWYtYvMhzQLj7aHfPcPeMxo0bl7VFEREJlMeW/gXAF+6eA+DuOe6e7+4FwPP8MISTDbQMWy4N2FQO7y8iIiVUHqF/DWFDO2bWPOy5K4ClwfRbwNVmVs3M2gDtgXnl8P4iIlJCZfqer5nVBM4Fbg0r/9nMuhAaull7+Dl3/8rMJgPLgDzgDh25IyJSucoU+u6+D2hUqHbDMeZ/BHikLO8pIiKlp2/kiojEEYW+iEgcUeiLiMQRhb6ISBxR6IuIxBGFvohIHFHoi4jEEYW+iEgcUeiLiMQRhb6ISBxR6IuIxBGFvohIHFHoi4jEEYW+iEgcUeiLiMQRhb6ISBxR6IuIxJEyXTkrqk25E/ZthxoNoWYDqNEgmG4Ydh/UkqtHulsRkUpR1mvkrgV2A/lAnrtnmFlD4J9AOqFr5P7c3b8L5n8IGB7Mf7e7f1CW9z+mvIPw/XrYtAj274C8A8XPm1wzFP41GgQriLCVQ40GR68okqpDYgokJge3FEhIhgR9cBKJa+5QkPfDLT8XCvKDx7nBfX5Qzyv6lh82ferF5Z4r5bGlP8jdt4U9/i0w1d3/ZGa/DR4/aGYdgauB04EWwMdmdkqFXRz9Z8//+HHufti3I7QC2LcD9n/34+nwWs5Xofv934EXlPw9LbHolcHh6cSk4D4FEsKmE5NC81kCmAFWxDTF1I81TejePXQj/L6giFpQ/1GNImoeel1L+OG9jkwXuoX3U9w84a+TkBj6d/zRfULo3+uoWtjjhKRgOqGI5cPqR5ZLCKufyHOHp2NoBV9QAJ4fCqPwe/ejawX5od+HI/MU/FArfDtSD3u+oKCIWhHLHw5Kz/9xf0fV8ororVDtqJ8hL7T8kdfK+2G+H9WKmefI47B5Dt/KO84ezoGE8h2JqIjhncuAgcH0eGA68GBQf8XdDwJrzCwL6AHMroAejpZcA+qlhm4lVVAAB3cFK4fvflgR5B0IranzcyH/UGgNfni6yHrYcwVh07n7f1wvHLrFhXGx08XM/6MVgv14+kf3CcepEfYcP6w4ivqDP3LzY8+Hl8t/b0QccyVXeEVX1ErPCs1/+DWKWkmH/98e47ljrdgLB/HhAKvqjqzwD6/8i9ogSCi0cZAUWnH/6HEiJKUcf54jr53041pC8g+1xKSw58NuiclhyySFLZMY9lz4/Cnl/s9V1tB34EMzc+Dv7j4aaO
rumwHcfbOZNQnmTQXmhC2bHdSOYmYjgZEArVq1KmOLZZCQADXqh24NI9dGTPOwQDrWFmVxW37Fbs3lFzFv2BZm+NblUUFYzNZqQbD1G/4aFF6Zhf88RazgjloZFp7Op+gV8rFWzBw975EVUHjtcGCFTxfxaehHn2iK+FQV/mmr8Kejw5+CjqzECn+CSihiuWKWDQ/zIj/tJcbWJ65KUtbQ7+Pum4Jg/8jMvjnGvFZErcjNvGDlMRogIyOjCm8KynEd+eSRQCwfVyASLcq0mnT3TcH9FuANQsM1OWbWHCC43xLMng20DFs8DdhUlvcXEZETU+rQN7NaZlbn8DRwHrAUeAsYFsw2DJgSTL8FXG1m1cysDdAemFfa9xcRkRNXls/TTYE3LLRTLwmY6O7/NrP5wGQzGw6sB4YAuPtXZjYZWAbkAXdU2JE7IiJSpFKHvruvBjoXUd8OnFPMMo8Aj5T2PUVEpGy061tEJI4o9EVE4ohCX0Qkjij0RUTiiLlH93efzGwrsK6Ui58EbDvuXNGhKvUKVavfqtQrVK1+q1KvULX6LWuvrd29ceFi1Id+WZhZprtnRLqPkqhKvULV6rcq9QpVq9+q1CtUrX4rqlcN74iIxBGFvohIHIn10B8d6QZOQFXqFapWv1WpV6ha/ValXqFq9Vshvcb0mL6IiPxYrG/pi4hIGIW+iEgcicnQN7PBZrbczLKC6/RGLTNraWbTzOxrM/vKzO6JdE/HY2aJZrbQzN6JdC/HY2b1zexVM/sm+Dc+K9I9FcfMfhX8Diw1s0lmVr4XRy0jMxtjZlvMbGlYraGZfWRmK4P7BpHsMVwx/T4a/C4sMbM3zKx+BFs8oqhew577tZm5mZ1UHu8Vc6FvZonA34ALgI7ANcFF2aNVHnC/u58G9ALuiPJ+Ae4Bvo50EyX0JPBvdz+V0Flho7JvM0sF7gYy3L0TkAhcHdmujjIOGFyo9ltgqru3B6YGj6PFOI7u9yOgk7ufCawAHqrspooxjqN7xcxaAucSOk19uYi50Cd09a4sd1/t7oeAVwhdlD0quftmd/8imN5NKJRO4OrtlcvM0oCLgBci3cvxmFldoD/wIoC7H3L37yPa1LElATXMLAmoSZRdWc7dZwI7CpUvA8YH0+OByyuzp2Mpql93/9Dd84KHcwhdwS/iivm3BXgceIBiLi1bGrEY+qnAhrDHxV6APdqYWTrQFZgb4VaO5QlCv4QFEe6jJNoCW4GxwXDUC8FV3qKOu28EHiO0RbcZ2OnuH0a2qxJp6u6bIbQBAzSJcD8n4mbg/Ug3URwzuxTY6O6Ly/N1YzH0S3wB9mhiZrWB14B73X1XpPspipldDGxx9wWR7qWEkoBuwLPu3hXYS3QNPxwRjIVfBrQBWgC1zOz6yHYVu8zsYUJDqy9HupeimFlN4GHg9+X92rEY+lXuAuxmlkwo8F9299cj3c8x9AEuNbO1hIbNzjazf0S2pWPKBrLd/fAnp1cJrQSi0U+BNe6+1d1zgdeB3hHuqSRyzKw5QHC/JcL9HJeZDQMuBq7z6P2i0smENgAWB39vacAXZtasrC8ci6E/H2hvZm3MLIXQzrC3ItxTsSx0keEXga/d/a+R7udY3P0hd09z93RC/66fuHvUbo26+7fABjPrEJTOIXSN5mi0HuhlZjWD34lziNKdzoW8BQwLpocBUyLYy3GZ2WDgQeBSd98X6X6K4+5funsTd08P/t6ygW7B73SZxFzoBztp7gQ+IPRHM9ndv4psV8fUB7iB0FbzouB2YaSbiiF3AS+b2RKgC/A/kW2naMGnkVeBL4AvCf1tRtUpA8xsEjAb6GBm2WY2HPgTcK6ZrSR0lMmfItljuGL6fRqoA3wU/K09F9EmA8X0WjHvFb2fbkREpLzF3Ja+iIgUT6EvIhJHFPoiInFEoS8iEkcU+iIicUShLyISRxT6IiJx5P8DBjfbPAvZaBoAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(np.arange(len(history_ppl_train)), history_ppl_train, history_ppl_valid)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inferencja"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Gości innych nie widział oprócz spółleśników'"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"'Gości innych nie widział oprócz spółleśników'"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"tokenized = list(tokenize('Gości innych nie widział oprócz spółleśników',lowercase = True))"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"#tokenized = tokenized[-NGRAMS :-1 ]"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['gości', 'innych', 'nie', 'widział', 'oprócz', 'spółleśników']"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenized"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"ids = []\n",
"for word in tokenized:\n",
" if word in vocab_stoi:\n",
" ids.append(vocab_stoi[word])\n",
" else:\n",
" ids.append(vocab_stoi['<UNK>'])"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[2671, 3168, 5873, 13240, 6938, 15001]"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ids"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"LSTM(\n",
" (emb): Embedding(15005, 100)\n",
" (rec): LSTM(100, 256, batch_first=True)\n",
" (fc1): Linear(in_features=256, out_features=15005, bias=True)\n",
")"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lm.eval()"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"ids = torch.tensor(ids, dtype = torch.long, device = device)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 2671, 3168, 5873, 13240, 6938, 15001], device='cuda:0')"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ids"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"preds= lm(ids.unsqueeze(0))"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"15001"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.argmax(torch.softmax(preds,1),1).item()"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.return_types.max(\n",
"values=tensor([0.1419], device='cuda:0', grad_fn=<MaxBackward0>),\n",
"indices=tensor([15001], device='cuda:0'))"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.max(torch.softmax(preds,1),1)"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'<UNK>'"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vocab_itos[torch.argmax(torch.softmax(preds,1),1).item()]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ZADANIE: GENEROWANIE TEKSTU\n",
"\n",
"Napisać funkcję generującą tekst, która dla podanego fragmentu generuje tekst.\n",
    "Generowanie tekstu ma wyglądać następująco: Z 10 najbardziej prawdopodobnych tokenów należy wylosować jeden, ale ma to być token inny niż specjalny (UNK, BOS, EOS, PAD). \n",
"\n",
    "Wygenerować tekst o długości 30 tokenów."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### generowanie tekstu"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"tokenized = list(tokenize('Pan Tadeusz', lowercase = True))"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['pan', 'tadeusz']"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenized"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"ids = []\n",
"for word in tokenized:\n",
" if word in vocab_stoi:\n",
" ids.append(vocab_stoi[word])\n",
" else:\n",
" ids.append(vocab_stoi['<UNK>'])\n",
"ids = torch.tensor([ids], dtype = torch.long, device = device)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"a\n",
"nie\n",
"ma\n",
"i\n",
"na\n",
"nim\n",
"na\n",
"w\n",
"tył\n",
"i\n",
"tak\n",
"w\n",
"tył\n",
"tylko\n",
"i\n",
"z\n",
"nim\n",
"na\n",
"litwie\n",
"a\n",
"tak\n",
"z\n",
"góry\n",
"w\n",
"górę\n",
"na\n",
"nie\n",
"a\n",
"tak\n",
"z\n"
]
}
],
"source": [
"candidates_number = 10\n",
"for i in range(30):\n",
" preds= lm(ids)\n",
" candidates = torch.topk(torch.softmax(preds,1),candidates_number)[1][0].cpu().numpy()\n",
" candidate = 15001\n",
" while candidate > 15000:\n",
" candidate = candidates[np.random.randint(candidates_number)]\n",
" print(vocab_itos[candidate])\n",
" ids = torch.cat((ids, torch.tensor([[candidate]], device = device)), 1)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}