{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# GRU word-level language model\n",
    "\n",
    "Trains a single-layer GRU that predicts the next token from the previous `NGRAMS - 1` tokens, reports train and held-out (`train_test`) perplexity after every epoch, and writes generated words for `dev-0` and `test-A`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "import random\n",
    "from collections import Counter\n",
    "\n",
    "import nltk\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import torch\n",
    "from nltk.tokenize import word_tokenize as tokenize\n",
    "from sklearn.model_selection import train_test_split\n",
    "from tqdm.notebook import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# tokenizer models used by nltk.word_tokenize\n",
    "nltk.download('punkt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# settings\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "print('Using {} device'.format(device))\n",
    "\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "BATCH_SIZE = 128\n",
    "EPOCHS = 15\n",
    "NGRAMS = 5               # window length: NGRAMS - 1 context tokens + 1 target token\n",
    "VOCAB_SIZE = 15005       # total vocabulary size, special tokens included\n",
    "CANDIDATES_NUMBER = 10   # size of the top-k pool sampled from during generation\n",
    "\n",
    "# Special tokens; they are appended after the regular vocabulary.\n",
    "UNK, PAD, BOS, EOS = '<unk>', '<pad>', '<s>', '</s>'\n",
    "SPECIAL_TOKENS = [UNK, PAD, BOS, EOS]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Data preparation\n",
    "\n",
    "The vocabulary keeps the `VOCAB_SIZE - len(SPECIAL_TOKENS)` most frequent training tokens. Ranking alphabetically instead would silently turn every late-sorting token into `<unk>`, however frequent it is. Each corpus becomes one token stream, framed by padding, `<s>` and `</s>`, and cut into sliding windows of `NGRAMS` ids."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def read_first_column(path):\n",
    "    \"\"\"Return the first tab-separated column of `path` as a Series of strings.\n",
    "\n",
    "    Malformed lines are skipped and quote characters are kept literally.\n",
    "    \"\"\"\n",
    "    read_kwargs = dict(header=None, quoting=csv.QUOTE_NONE, sep='\\t')\n",
    "    try:\n",
    "        frame = pd.read_csv(path, on_bad_lines='skip', **read_kwargs)\n",
    "    except TypeError:  # pandas < 1.3 has no `on_bad_lines`\n",
    "        frame = pd.read_csv(path, error_bad_lines=False, **read_kwargs)\n",
    "    return frame[0].dropna().astype(str)\n",
    "\n",
    "\n",
    "def write_lines(path, lines):\n",
    "    \"\"\"Write each string of `lines` to `path` (UTF-8), one per line.\"\"\"\n",
    "    with open(path, 'w', encoding='utf-8') as out_file:\n",
    "        for line in lines:\n",
    "            out_file.write(line + '\\n')\n",
    "\n",
    "\n",
    "def tokenize_lower(texts):\n",
    "    \"\"\"Tokenize every text with NLTK and return one flat list of lowercased tokens.\"\"\"\n",
    "    return [token.lower() for text in texts for token in tokenize(text)]\n",
    "\n",
    "\n",
    "def build_vocab(tokens, vocab_size=VOCAB_SIZE, specials=SPECIAL_TOKENS):\n",
    "    \"\"\"Return `(itos, stoi)`: the most frequent tokens followed by `specials`.\n",
    "\n",
    "    At most `vocab_size` entries in total; ties keep first-occurrence order.\n",
    "    \"\"\"\n",
    "    n_regular = vocab_size - len(specials)\n",
    "    ranked = [token for token, _ in Counter(tokens).most_common() if token not in specials]\n",
    "    itos = ranked[:n_regular] + list(specials)\n",
    "    stoi = {token: idx for idx, token in enumerate(itos)}\n",
    "    return itos, stoi\n",
    "\n",
    "\n",
    "def make_ngram_tensor(tokens, stoi, ngrams=NGRAMS):\n",
    "    \"\"\"Return every sliding window of `ngrams` ids over `tokens` as a (n_windows, ngrams) LongTensor.\n",
    "\n",
    "    The id stream is `ngrams - 1` PAD ids, BOS, the token ids (out-of-vocabulary -> UNK), EOS.\n",
    "    \"\"\"\n",
    "    unk_id = stoi[UNK]\n",
    "    ids = [stoi[PAD]] * (ngrams - 1) + [stoi[BOS]]\n",
    "    ids.extend(stoi.get(token, unk_id) for token in tokens)\n",
    "    ids.append(stoi[EOS])\n",
    "    # `+ 1` so that the last window (the one ending in EOS) is included.\n",
    "    windows = [ids[i:i + ngrams] for i in range(len(ids) - ngrams + 1)]\n",
    "    return torch.tensor(windows, dtype=torch.long, device=device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "train_raw = read_first_column('train/train.tsv')\n",
    "train_set, train_test_set = train_test_split(train_raw, test_size=0.2, random_state=SEED)\n",
    "# The split is also written to disk, one text per line.\n",
    "write_lines('train/train_set.tsv', train_set)\n",
    "write_lines('train/train_test_set.tsv', train_test_set)\n",
    "\n",
    "train_set_tok = tokenize_lower(train_set)\n",
    "train_test_set_tok = tokenize_lower(train_test_set)\n",
    "\n",
    "vocab_itos, vocab_stoi = build_vocab(train_set_tok)\n",
    "SPECIAL_IDS = [vocab_stoi[token] for token in SPECIAL_TOKENS]\n",
    "assert len(vocab_itos) <= VOCAB_SIZE and len(vocab_stoi) == len(vocab_itos)\n",
    "\n",
    "train_ids = make_ngram_tensor(train_set_tok, vocab_stoi)\n",
    "train_test_ids = make_ngram_tensor(train_test_set_tok, vocab_stoi)\n",
    "\n",
    "print('unique training tokens:', len(set(train_set_tok)))\n",
    "print('vocabulary size:', len(vocab_itos))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class GRU(torch.nn.Module):\n",
    "    \"\"\"Next-token model: embedding -> 1-layer GRU -> dropout -> linear layer giving vocabulary logits.\"\"\"\n",
    "\n",
    "    def __init__(self, vocab_size=None, emb_dim=100, hidden_dim=256, dropout=0.5):\n",
    "        super().__init__()\n",
    "        if vocab_size is None:\n",
    "            vocab_size = len(vocab_itos)\n",
    "        self.emb = torch.nn.Embedding(vocab_size, emb_dim)\n",
    "        self.rec = torch.nn.GRU(emb_dim, hidden_dim, 1, batch_first=True)\n",
    "        self.fc1 = torch.nn.Linear(hidden_dim, vocab_size)\n",
    "        self.dropout = torch.nn.Dropout(dropout)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map ids of shape (batch, seq_len) to next-token logits of shape (batch, vocab_size).\"\"\"\n",
    "        emb = self.emb(x)\n",
    "        _, h_n = self.rec(emb)  # h_n: (num_layers=1, batch, hidden_dim)\n",
    "        hidden = h_n.squeeze(0)\n",
    "        # Dropout is applied to the hidden state, not to the logits: zeroing logits\n",
    "        # would distort the softmax the loss is computed on.\n",
    "        return self.fc1(self.dropout(hidden))\n",
    "\n",
    "\n",
    "lm = GRU().to(device)\n",
    "optimizer = torch.optim.Adam(lm.parameters(), lr=0.0001)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def iterate_batches(data, batch_size=BATCH_SIZE, shuffle=False):\n",
    "    \"\"\"Yield `(X, Y)` mini-batches of n-gram windows: X = context columns, Y = last (target) column.\n",
    "\n",
    "    The final, possibly smaller, batch is included.\n",
    "    \"\"\"\n",
    "    if shuffle:\n",
    "        order = torch.randperm(len(data), device=data.device)\n",
    "    else:\n",
    "        order = torch.arange(len(data), device=data.device)\n",
    "    for start in range(0, len(data), batch_size):\n",
    "        batch = data[order[start:start + batch_size]]\n",
    "        yield batch[:, :-1], batch[:, -1]\n",
    "\n",
    "\n",
    "def train_epoch(model, data):\n",
    "    \"\"\"Run one optimization pass over `data` in shuffled mini-batches.\"\"\"\n",
    "    model.train()\n",
    "    n_batches = (len(data) + BATCH_SIZE - 1) // BATCH_SIZE\n",
    "    for X, Y in tqdm(iterate_batches(data, shuffle=True), total=n_batches):\n",
    "        loss = criterion(model(X), Y)\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def perplexity(model, data):\n",
    "    \"\"\"Return exp(mean cross-entropy) of `model` over all windows in `data`.\"\"\"\n",
    "    model.eval()\n",
    "    total_loss, total_count = 0.0, 0\n",
    "    for X, Y in iterate_batches(data):\n",
    "        # criterion averages over the batch; re-weight so the last, smaller batch counts correctly\n",
    "        total_loss += criterion(model(X), Y).item() * len(Y)\n",
    "        total_count += len(Y)\n",
    "    return float(np.exp(total_loss / total_count))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "hppl_train = []\n",
    "hppl_train_test = []\n",
    "for epoch in range(EPOCHS):\n",
    "    train_epoch(lm, train_ids)\n",
    "    ppl_train = perplexity(lm, train_ids)\n",
    "    ppl_train_test = perplexity(lm, train_test_ids)\n",
    "    hppl_train.append(ppl_train)\n",
    "    hppl_train_test.append(ppl_train_test)\n",
    "    print('epoch: ', epoch)\n",
    "    print('train ppl: ', ppl_train)\n",
    "    print('train_test ppl: ', ppl_train_test)\n",
    "    print()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generation\n",
    "\n",
    "Words are sampled uniformly from the `CANDIDATES_NUMBER` highest-scoring non-special tokens, conditioned on the last `NGRAMS - 1` ids (the context length seen in training)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def encode(text):\n",
    "    \"\"\"Return `text` as a (1, seq_len) LongTensor: PAD * (NGRAMS - 1), BOS, then lowercased token ids (OOV -> UNK).\n",
    "\n",
    "    The prefix matches how the training stream starts, so short prompts see familiar contexts.\n",
    "    \"\"\"\n",
    "    unk_id = vocab_stoi[UNK]\n",
    "    ids = [vocab_stoi[PAD]] * (NGRAMS - 1) + [vocab_stoi[BOS]]\n",
    "    ids += [vocab_stoi.get(token.lower(), unk_id) for token in tokenize(text)]\n",
    "    return torch.tensor([ids], dtype=torch.long, device=device)\n",
    "\n",
    "\n",
    "@torch.no_grad()\n",
    "def next_token_logits(model, ids):\n",
    "    \"\"\"Return (1, vocab_size) logits for the token after `ids`, with special tokens masked to -inf.\"\"\"\n",
    "    model.eval()\n",
    "    logits = model(ids[:, -(NGRAMS - 1):])  # the model was trained on NGRAMS - 1 context ids only\n",
    "    logits[:, SPECIAL_IDS] = float('-inf')\n",
    "    return logits\n",
    "\n",
    "\n",
    "def generate(model, ids, n_tokens, k=CANDIDATES_NUMBER):\n",
    "    \"\"\"Extend `ids` by `n_tokens` tokens, each drawn uniformly from the top-`k` candidates.\n",
    "\n",
    "    Returns `(words, extended_ids)`.\n",
    "    \"\"\"\n",
    "    words = []\n",
    "    for _ in range(n_tokens):\n",
    "        candidates = torch.topk(next_token_logits(model, ids), k, dim=1).indices[0].cpu().numpy()\n",
    "        candidate = int(candidates[np.random.randint(len(candidates))])\n",
    "        words.append(vocab_itos[candidate])\n",
    "        ids = torch.cat((ids, torch.tensor([[candidate]], dtype=torch.long, device=device)), 1)\n",
    "    return words, ids"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 'Gości' and 'Lalka'\n",
    "prompt_ids = encode('Gości innych nie widział oprócz spółleśników')\n",
    "print('next word:', vocab_itos[next_token_logits(lm, prompt_ids).argmax(dim=1).item()])\n",
    "\n",
    "lalka_words, lalka_ids = generate(lm, encode('Lalka'), 30)\n",
    "print(' '.join(lalka_words))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Predictions for `dev-0` and `test-A`\n",
    "\n",
    "**TODO:** the generated words are conditioned only on the seed context; the contents of `in.tsv` are used just for the line count."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def write_predictions(model, in_path, out_path, seed_ids):\n",
    "    \"\"\"Write one generated word per line of `in_path` to `out_path` (UTF-8) and return the words.\n",
    "\n",
    "    NOTE(review): only the number of lines in `in_path` is used, not their text.\n",
    "    \"\"\"\n",
    "    with open(in_path, 'r', encoding='utf-8') as in_file:\n",
    "        n_lines = sum(1 for _ in in_file)\n",
    "    words, _ = generate(model, seed_ids, n_lines)\n",
    "    with open(out_path, 'w', encoding='utf-8') as out_file:\n",
    "        out_file.writelines(word + '\\n' for word in words)\n",
    "    return words"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# dev-0 predictions\n",
    "dev_words = write_predictions(lm, 'dev-0/in.tsv', 'dev-0/out.tsv', lalka_ids)\n",
    "print(len(dev_words), 'lines;', ' '.join(dev_words[:30]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test-A predictions (line count taken from test-A/in.tsv)\n",
    "test_a_words = write_predictions(lm, 'test-A/in.tsv', 'test-A/out.tsv', lalka_ids)\n",
    "print(len(test_a_words), 'lines;', ' '.join(test_a_words[:30]))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}