ISI-lstm-lm/lstm - ODPOWIEDZI.ipynb
2021-05-31 15:05:47 +02:00

1466 lines
47 KiB
Plaintext

{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## importy"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/media/kuba/ssdsam/anaconda3/lib/python3.8/site-packages/gensim/similarities/__init__.py:15: UserWarning: The gensim.similarities.levenshtein submodule is disabled, because the optional Levenshtein package <https://pypi.org/project/python-Levenshtein/> is unavailable. Install Levenhstein (e.g. `pip install python-Levenshtein`) to suppress this warning.\n",
" warnings.warn(msg)\n"
]
}
],
"source": [
"from gensim.utils import tokenize\n",
"import numpy as np\n",
"import torch\n",
"from tqdm.notebook import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"#device = 'cpu'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Using cuda device\n"
]
}
],
"source": [
"print('Using {} device'.format(device))"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"device(type='cuda')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"device"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## przygotowanie zbiorów"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"pan_tadeusz_path_train= '/home/kuba/Syncthing/przedmioty/2020-02/ISI/zajecia9_ngramowy_model_jDDezykowy/pan-tadeusz-train.txt'"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"pan_tadeusz_path_valid= '/home/kuba/Syncthing/przedmioty/2020-02/ISI/zajecia9_ngramowy_model_jDDezykowy/pan-tadeusz-test.txt'"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"corpora_train = open(pan_tadeusz_path_train).read()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"corpora_train_tokenized = list(tokenize(corpora_train,lowercase = True))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"vocab_itos = sorted(set(corpora_train_tokenized))"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"16598"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(vocab_itos)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"vocab_itos = vocab_itos[:15005]\n",
"vocab_itos[15001] = \"<UNK>\"\n",
"vocab_itos[15002] = \"<BOS>\"\n",
"vocab_itos[15003] = \"<EOS>\"\n",
"vocab_itos[15004] = \"<PAD>\""
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"15005"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(vocab_itos)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"vocab_stoi = dict()\n",
"for i, token in enumerate(vocab_itos):\n",
" vocab_stoi[token] = i"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"NGRAMS = 5"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"def get_token_id(dataset):\n",
" token_ids = [vocab_stoi['<PAD>']] * (NGRAMS-1) + [vocab_stoi['<BOS>']]\n",
" for token in dataset:\n",
" try:\n",
" token_ids.append(vocab_stoi[token])\n",
" except KeyError:\n",
" token_ids.append(vocab_stoi['<UNK>'])\n",
" token_ids.append(vocab_stoi['<EOS>'])\n",
" return token_ids"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"train_ids = get_token_id(corpora_train_tokenized)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"[15004,\n",
" 15004,\n",
" 15004,\n",
" 15004,\n",
" 15002,\n",
" 7,\n",
" 5002,\n",
" 7247,\n",
" 11955,\n",
" 1432,\n",
" 7018,\n",
" 14739,\n",
" 5506,\n",
" 4696,\n",
" 4276,\n",
" 7505,\n",
" 2642,\n",
" 8477,\n",
" 7259,\n",
" 10870,\n",
" 10530,\n",
" 7506,\n",
" 12968,\n",
" 7997,\n",
" 1911,\n",
" 12479,\n",
" 11129,\n",
" 13069,\n",
" 11797,\n",
" 5819]"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_ids[:30]"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
"def get_samples(dataset):\n",
" samples = []\n",
" for i in range(len(dataset)-NGRAMS):\n",
" samples.append(dataset[i:i+NGRAMS])\n",
" return samples"
]
},
{
"cell_type": "code",
"execution_count": 19,
"metadata": {},
"outputs": [],
"source": [
"train_ids = get_samples(train_ids)"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {},
"outputs": [],
"source": [
"train_ids = torch.tensor(train_ids, device = device)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[15004, 15004, 15004, 15004, 15002],\n",
" [15004, 15004, 15004, 15002, 7],\n",
" [15004, 15004, 15002, 7, 5002],\n",
" [15004, 15002, 7, 5002, 7247],\n",
" [15002, 7, 5002, 7247, 11955],\n",
" [ 7, 5002, 7247, 11955, 1432],\n",
" [ 5002, 7247, 11955, 1432, 7018],\n",
" [ 7247, 11955, 1432, 7018, 14739],\n",
" [11955, 1432, 7018, 14739, 5506],\n",
" [ 1432, 7018, 14739, 5506, 4696],\n",
" [ 7018, 14739, 5506, 4696, 4276],\n",
" [14739, 5506, 4696, 4276, 7505],\n",
" [ 5506, 4696, 4276, 7505, 2642],\n",
" [ 4696, 4276, 7505, 2642, 8477],\n",
" [ 4276, 7505, 2642, 8477, 7259],\n",
" [ 7505, 2642, 8477, 7259, 10870],\n",
" [ 2642, 8477, 7259, 10870, 10530],\n",
" [ 8477, 7259, 10870, 10530, 7506],\n",
" [ 7259, 10870, 10530, 7506, 12968],\n",
" [10870, 10530, 7506, 12968, 7997],\n",
" [10530, 7506, 12968, 7997, 1911],\n",
" [ 7506, 12968, 7997, 1911, 12479],\n",
" [12968, 7997, 1911, 12479, 11129],\n",
" [ 7997, 1911, 12479, 11129, 13069],\n",
" [ 1911, 12479, 11129, 13069, 11797],\n",
" [12479, 11129, 13069, 11797, 5819],\n",
" [11129, 13069, 11797, 5819, 6268],\n",
" [13069, 11797, 5819, 6268, 2807],\n",
" [11797, 5819, 6268, 2807, 7831],\n",
" [ 5819, 6268, 2807, 7831, 12893]], device='cuda:0')"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_ids[:30]"
]
},
{
"cell_type": "code",
"execution_count": 22,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([57022, 5])"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"train_ids.shape"
]
},
{
"cell_type": "code",
"execution_count": 23,
"metadata": {},
"outputs": [],
"source": [
"corpora_valid = open(pan_tadeusz_path_valid).read()"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {},
"outputs": [],
"source": [
"corpora_valid_tokenized = list(tokenize(corpora_valid,lowercase = True))"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [],
"source": [
"valid_ids = get_token_id(corpora_valid_tokenized)"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"valid_ids = torch.tensor(get_samples(valid_ids), dtype = torch.long, device = device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## model"
]
},
{
"cell_type": "code",
"execution_count": 27,
"metadata": {},
"outputs": [],
"source": [
"# https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html"
]
},
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [],
"source": [
"class LSTM(torch.nn.Module):\n",
"\n",
" def __init__(self):\n",
" super(LSTM, self).__init__()\n",
" self.emb = torch.nn.Embedding(len(vocab_itos),100)\n",
" self.rec = torch.nn.LSTM(100, 256, 1, batch_first = True)\n",
" self.fc1 = torch.nn.Linear( 256 ,len(vocab_itos))\n",
" #self.dropout = torch.nn.Dropout(0.5)\n",
"\n",
" def forward(self, x):\n",
" emb = self.emb(x)\n",
" #emb = self.dropout(emb)\n",
" output, (h_n, c_n) = self.rec(emb)\n",
" hidden = h_n.squeeze(0)\n",
" out = self.fc1(hidden)\n",
" #out = self.dropout(out)\n",
" return out"
]
},
{
"cell_type": "code",
"execution_count": 29,
"metadata": {},
"outputs": [],
"source": [
"lm = LSTM().to(device)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"criterion = torch.nn.CrossEntropyLoss()"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [],
"source": [
"optimizer = torch.optim.Adam(lm.parameters(),lr=0.0001)"
]
},
{
"cell_type": "code",
"execution_count": 32,
"metadata": {},
"outputs": [],
"source": [
"BATCH_SIZE = 128\n",
"EPOCHS = 15"
]
},
{
"cell_type": "code",
"execution_count": 33,
"metadata": {},
"outputs": [],
"source": [
"def get_ppl(dataset_ids):\n",
" lm.eval()\n",
"\n",
" batches = 0\n",
" loss_sum =0\n",
" acc_score = 0\n",
"\n",
" for i in range(0, len(dataset_ids)-BATCH_SIZE+1, BATCH_SIZE):\n",
" X = dataset_ids[i:i+BATCH_SIZE,:NGRAMS-1]\n",
" Y = dataset_ids[i:i+BATCH_SIZE,NGRAMS-1]\n",
" predictions = lm(X)\n",
" \n",
" # equally distributted\n",
" # predictions = torch.zeros_like(predictions)\n",
" \n",
" loss = criterion(predictions,Y)\n",
"\n",
" loss_sum += loss.item()\n",
" batches += 1\n",
"\n",
" return np.exp(loss_sum / batches)"
]
},
{
"cell_type": "code",
"execution_count": 34,
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "2c8985983fcd4ec5b93ba32b0345d000",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 0\n",
"train ppl: 2296.6914856482526\n",
"valid ppl: 528.9542436139727\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d2acde1a6e19416cb1eaee8f40f0114f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 1\n",
"train ppl: 2093.302103954666\n",
"valid ppl: 514.4726844027333\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c9df3b66aede47febbe6852ca62eed17",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 2\n",
"train ppl: 2014.09679023559\n",
"valid ppl: 510.12146471773366\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d5272c8b25f74c5a9383a829ed30308e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 3\n",
"train ppl: 1939.0594855086504\n",
"valid ppl: 509.1060151440451\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d6656f2e796b44b7b03d204ecd89f89b",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 4\n",
"train ppl: 1854.4566511885196\n",
"valid ppl: 510.02244291272973\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "86419de761c1443c86c70726099b9838",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 5\n",
"train ppl: 1755.030202547313\n",
"valid ppl: 508.494174178397\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5d12b6fcfc0045c9b8bac58343036dff",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 6\n",
"train ppl: 1646.180912657662\n",
"valid ppl: 506.06383737670035\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "d6593e6387d5439a9c1301a366fecda6",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 7\n",
"train ppl: 1533.0501876139222\n",
"valid ppl: 504.08067276707567\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "32742824eb00458b8586795a7fb56d84",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 8\n",
"train ppl: 1420.680717507558\n",
"valid ppl: 502.6906095632547\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "5ee78d1ebd03415885d4ea2663c0ee7f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 9\n",
"train ppl: 1311.1083504083306\n",
"valid ppl: 503.5230045363773\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b30185826c4341b088934df92839852c",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 10\n",
"train ppl: 1203.498635587493\n",
"valid ppl: 505.7599916969862\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "7eeeeb3ef93a45f3a6709a8579164c79",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 11\n",
"train ppl: 1100.0681613054269\n",
"valid ppl: 507.6071195979723\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "919fd04ac9ed476e931149c15dc460db",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 12\n",
"train ppl: 1003.217414775517\n",
"valid ppl: 510.07952767103245\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a30dc2c9bd194ead94be7de3275a1e5e",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 13\n",
"train ppl: 912.2987798296267\n",
"valid ppl: 512.8275727599236\n",
"\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "94f2bac40c83416fad59742f2e2adff4",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, max=445.0), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"epoch: 14\n",
"train ppl: 826.911431868259\n",
"valid ppl: 516.1525759633064\n",
"\n"
]
}
],
"source": [
"history_ppl_train = []\n",
"history_ppl_valid = []\n",
"for epoch in range(EPOCHS):\n",
" \n",
" batches = 0\n",
" loss_sum =0\n",
" acc_score = 0\n",
" lm.train()\n",
" #for i in range(0, len(train_ids)-BATCH_SIZE+1, BATCH_SIZE):\n",
" for i in tqdm(range(0, len(train_ids)-BATCH_SIZE+1, BATCH_SIZE)):\n",
" X = train_ids[i:i+BATCH_SIZE,:NGRAMS-1]\n",
" Y = train_ids[i:i+BATCH_SIZE,NGRAMS-1]\n",
" predictions = lm(X)\n",
" loss = criterion(predictions,Y)\n",
" \n",
" \n",
" \n",
" optimizer.zero_grad()\n",
" loss.backward()\n",
" optimizer.step()\n",
" \n",
" loss_sum += loss.item()\n",
" batches += 1\n",
" \n",
" ppl_train = get_ppl(train_ids)\n",
" ppl_valid = get_ppl(valid_ids)\n",
" \n",
" history_ppl_train.append(ppl_train)\n",
" history_ppl_valid.append(ppl_valid)\n",
" \n",
" print('epoch: ', epoch)\n",
" print('train ppl: ', ppl_train)\n",
" print('valid ppl: ', ppl_valid)\n",
" print()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## parametry modelu"
]
},
{
"cell_type": "code",
"execution_count": 35,
"metadata": {
"scrolled": false
},
"outputs": [
{
"data": {
"text/plain": [
"[Parameter containing:\n",
" tensor([[ 0.4949, -1.2472, -0.7167, ..., -0.0801, -2.1905, 0.9790],\n",
" [ 0.2070, 1.7394, 0.8255, ..., -0.5796, -0.3776, 0.8831],\n",
" [-1.4559, 1.1073, -0.4904, ..., -1.2919, -2.2661, -0.5476],\n",
" ...,\n",
" [-0.3706, 0.2133, 0.0484, ..., -0.5792, -0.5769, -0.6941],\n",
" [ 0.5502, -0.1212, -2.0879, ..., 0.6764, -0.5961, -0.6282],\n",
" [ 0.8362, 0.2193, -0.0807, ..., 2.7741, -0.2589, 0.3310]],\n",
" device='cuda:0', requires_grad=True),\n",
" Parameter containing:\n",
" tensor([[ 0.0252, -0.0744, 0.0817, ..., -0.0559, 0.0896, 0.0208],\n",
" [ 0.0423, 0.0329, -0.0610, ..., -0.0009, 0.0169, -0.0361],\n",
" [ 0.0507, -0.0838, 0.0520, ..., 0.0395, 0.0067, 0.0173],\n",
" ...,\n",
" [ 0.0669, 0.0430, -0.0306, ..., 0.0096, 0.0619, -0.0992],\n",
" [-0.0153, -0.0888, 0.0580, ..., -0.0433, 0.0399, -0.0494],\n",
" [-0.0067, -0.0053, -0.0242, ..., 0.0017, -0.0306, -0.0972]],\n",
" device='cuda:0', requires_grad=True),\n",
" Parameter containing:\n",
" tensor([[ 0.0212, -0.0425, -0.0329, ..., -0.0206, 0.0839, 0.0286],\n",
" [ 0.0952, 0.0298, -0.1211, ..., -0.0468, -0.0233, -0.0620],\n",
" [ 0.0108, 0.0422, -0.0492, ..., -0.0288, -0.0231, 0.0078],\n",
" ...,\n",
" [ 0.0253, 0.0154, -0.0765, ..., -0.0025, 0.0057, -0.0408],\n",
" [ 0.0892, -0.0928, -0.1039, ..., -0.1531, -0.0011, 0.0180],\n",
" [ 0.1341, 0.0666, -0.0548, ..., -0.0573, -0.0376, -0.0813]],\n",
" device='cuda:0', requires_grad=True),\n",
" Parameter containing:\n",
" tensor([0.1115, 0.0423, 0.1438, ..., 0.0361, 0.0297, 0.0956], device='cuda:0',\n",
" requires_grad=True),\n",
" Parameter containing:\n",
" tensor([0.0786, 0.1201, 0.0857, ..., 0.1177, 0.1319, 0.0886], device='cuda:0',\n",
" requires_grad=True),\n",
" Parameter containing:\n",
" tensor([[-0.0158, 0.0236, -0.0958, ..., -0.0906, -0.0678, 0.0057],\n",
" [-0.0871, -0.0788, 0.1217, ..., -0.0231, -0.0102, 0.0220],\n",
" [ 0.0265, -0.0680, -0.0219, ..., -0.0520, -0.0565, 0.0628],\n",
" ...,\n",
" [-0.0618, 0.0232, 0.0898, ..., 0.1069, -0.0112, 0.0103],\n",
" [-0.0489, 0.0708, 0.0546, ..., 0.1186, -0.0987, 0.1411],\n",
" [-0.0764, 0.0463, 0.0947, ..., 0.1104, -0.0312, 0.1118]],\n",
" device='cuda:0', requires_grad=True),\n",
" Parameter containing:\n",
" tensor([ 0.0299, -0.0551, -0.0323, ..., -0.0371, -0.0297, -0.0157],\n",
" device='cuda:0', requires_grad=True)]"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"list(lm.parameters())"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### krzywe uczenia"
]
},
{
"cell_type": "code",
"execution_count": 36,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"[528.9542436139727,\n",
" 514.4726844027333,\n",
" 510.12146471773366,\n",
" 509.1060151440451,\n",
" 510.02244291272973,\n",
" 508.494174178397,\n",
" 506.06383737670035,\n",
" 504.08067276707567,\n",
" 502.6906095632547,\n",
" 503.5230045363773,\n",
" 505.7599916969862,\n",
" 507.6071195979723,\n",
" 510.07952767103245,\n",
" 512.8275727599236,\n",
" 516.1525759633064]"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"history_ppl_valid"
]
},
{
"cell_type": "code",
"execution_count": 37,
"metadata": {},
"outputs": [],
"source": [
"import matplotlib.pyplot as plt"
]
},
{
"cell_type": "code",
"execution_count": 38,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7f02842a99a0>,\n",
" <matplotlib.lines.Line2D at 0x7f02842a9a90>]"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(np.arange(len(history_ppl_train)), history_ppl_train, history_ppl_valid)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Inferencja"
]
},
{
"cell_type": "code",
"execution_count": 39,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Gości innych nie widział oprócz spółleśników'"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"'Gości innych nie widział oprócz spółleśników'"
]
},
{
"cell_type": "code",
"execution_count": 40,
"metadata": {},
"outputs": [],
"source": [
"tokenized = list(tokenize('Gości innych nie widział oprócz spółleśników',lowercase = True))"
]
},
{
"cell_type": "code",
"execution_count": 41,
"metadata": {},
"outputs": [],
"source": [
"#tokenized = tokenized[-NGRAMS :-1 ]"
]
},
{
"cell_type": "code",
"execution_count": 42,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['gości', 'innych', 'nie', 'widział', 'oprócz', 'spółleśników']"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenized"
]
},
{
"cell_type": "code",
"execution_count": 43,
"metadata": {},
"outputs": [],
"source": [
"ids = []\n",
"for word in tokenized:\n",
" if word in vocab_stoi:\n",
" ids.append(vocab_stoi[word])\n",
" else:\n",
" ids.append(vocab_stoi['<UNK>'])"
]
},
{
"cell_type": "code",
"execution_count": 44,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[2671, 3168, 5873, 13240, 6938, 15001]"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ids"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"LSTM(\n",
" (emb): Embedding(15005, 100)\n",
" (rec): LSTM(100, 256, batch_first=True)\n",
" (fc1): Linear(in_features=256, out_features=15005, bias=True)\n",
")"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"lm.eval()"
]
},
{
"cell_type": "code",
"execution_count": 46,
"metadata": {},
"outputs": [],
"source": [
"ids = torch.tensor(ids, dtype = torch.long, device = device)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 2671, 3168, 5873, 13240, 6938, 15001], device='cuda:0')"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ids"
]
},
{
"cell_type": "code",
"execution_count": 48,
"metadata": {},
"outputs": [],
"source": [
"preds= lm(ids.unsqueeze(0))"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"15001"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.argmax(torch.softmax(preds,1),1).item()"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.return_types.max(\n",
"values=tensor([0.1419], device='cuda:0', grad_fn=<MaxBackward0>),\n",
"indices=tensor([15001], device='cuda:0'))"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.max(torch.softmax(preds,1),1)"
]
},
{
"cell_type": "code",
"execution_count": 51,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'<UNK>'"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"vocab_itos[torch.argmax(torch.softmax(preds,1),1).item()]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ZADANIE: GENEROWANIE TEKSTU\n",
"\n",
"Napisać funkcję generującą tekst, która dla podanego fragmentu generuje tekst.\n",
"Generowanie tekstu ma wyglądać następująco: Z 10 najbardziej prawodpodobnych tokenów należy wylosować jeden, ala ma to byc token inny niż specjalny (UNK, BOS, EOS, PAD). \n",
"\n",
"wygenerować tekst o długości 30 tokenów"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### generowanie tekstu"
]
},
{
"cell_type": "code",
"execution_count": 52,
"metadata": {},
"outputs": [],
"source": [
"tokenized = list(tokenize('Pan Tadeusz', lowercase = True))"
]
},
{
"cell_type": "code",
"execution_count": 53,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['pan', 'tadeusz']"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"tokenized"
]
},
{
"cell_type": "code",
"execution_count": 54,
"metadata": {},
"outputs": [],
"source": [
"ids = []\n",
"for word in tokenized:\n",
" if word in vocab_stoi:\n",
" ids.append(vocab_stoi[word])\n",
" else:\n",
" ids.append(vocab_stoi['<UNK>'])\n",
"ids = torch.tensor([ids], dtype = torch.long, device = device)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"a\n",
"nie\n",
"ma\n",
"i\n",
"na\n",
"nim\n",
"na\n",
"w\n",
"tył\n",
"i\n",
"tak\n",
"w\n",
"tył\n",
"tylko\n",
"i\n",
"z\n",
"nim\n",
"na\n",
"litwie\n",
"a\n",
"tak\n",
"z\n",
"góry\n",
"w\n",
"górę\n",
"na\n",
"nie\n",
"a\n",
"tak\n",
"z\n"
]
}
],
"source": [
"candidates_number = 10\n",
"for i in range(30):\n",
" preds= lm(ids)\n",
" candidates = torch.topk(torch.softmax(preds,1),candidates_number)[1][0].cpu().numpy()\n",
" candidate = 15001\n",
" while candidate > 15000:\n",
" candidate = candidates[np.random.randint(candidates_number)]\n",
" print(vocab_itos[candidate])\n",
" ids = torch.cat((ids, torch.tensor([[candidate]], device = device)), 1)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}