forked from kubapok/lalka-lm
1725 lines
34 KiB
Plaintext
1725 lines
34 KiB
Plaintext
{
|
|
"cells": [
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 18,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
"import nltk\n",
|
|
"import torch\n",
|
|
"import pandas as pd\n",
|
|
"import csv\n",
|
|
"from sklearn.model_selection import train_test_split\n",
|
|
"from nltk.tokenize import word_tokenize as tokenize\n",
|
|
"from tqdm.notebook import tqdm\n",
|
|
"import numpy as np"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stderr",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"[nltk_data] Downloading package punkt to /home/ubuntu/nltk_data...\n",
|
|
"[nltk_data] Package punkt is already up-to-date!\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"text/plain": [
|
|
"True"
|
|
]
|
|
},
|
|
"execution_count": 3,
|
|
"metadata": {},
|
|
"output_type": "execute_result"
|
|
}
|
|
],
|
|
"source": [
|
|
"#downloads\n",
|
|
"nltk.download('punkt')"
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 4,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"Using cpu device\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
# Run configuration: compute device, loss function, and training hyper-parameters.
BATCH_SIZE = 128   # n-gram windows per optimization step
EPOCHS = 15        # full passes over the training windows
NGRAMS = 5         # window length: NGRAMS-1 context tokens + 1 target token

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print('Using {} device'.format(device))

criterion = torch.nn.CrossEntropyLoss()
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 14,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"28558\n",
|
|
"15005\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
# Training data preparation: split train.tsv 80/20, build a 15 005-entry vocabulary
# (15 001 lexicographically-first tokens + 4 specials), and encode both splits as
# overlapping NGRAMS-wide id windows for next-token prediction.

def _encode(tokens, stoi, ngrams):
    """Map tokens to vocab ids, framed as <PAD>*(ngrams-1) + <BOS> ... <EOS>.

    Out-of-vocabulary tokens map to <UNK>.
    """
    ids = [stoi['<PAD>']] * (ngrams - 1) + [stoi['<BOS>']]
    unk = stoi['<UNK>']
    for token in tokens:
        ids.append(stoi.get(token, unk))
    ids.append(stoi['<EOS>'])
    return ids

def _ngram_tensor(ids, ngrams):
    """All overlapping windows of length `ngrams` as a (N, ngrams) long tensor."""
    samples = [ids[i:i + ngrams] for i in range(len(ids) - ngrams)]
    return torch.tensor(samples, dtype=torch.long, device=device)

train_data = pd.read_csv('train/train.tsv', header=None, error_bad_lines=False,
                         quoting=csv.QUOTE_NONE, sep='\t')
train_data = train_data[0]
train_set, train_test_set = train_test_split(train_data, test_size=0.2)

# BUG FIX: write one row per line. The original wrote rows with no separator,
# concatenating the whole split onto a single line and fusing the last word of
# each row with the first word of the next before tokenization.
with open("train/train_set.tsv", "w", encoding='utf-8') as out_train_set:
    for i in train_set:
        out_train_set.write(i + '\n')
with open("train/train_test_set.tsv", "w", encoding='utf-8') as out_train_test_set:
    for i in train_test_set:
        out_train_test_set.write(i + '\n')

train_set_tok = list(tokenize(open('train/train_set.tsv', encoding='utf-8').read()))
train_set_tok = [line.lower() for line in train_set_tok]

vocab_itos = sorted(set(train_set_tok))
print(len(vocab_itos))

# NOTE(review): the vocabulary keeps the 15 001 lexicographically *first* tokens,
# not the most frequent ones — preserved as-is, but frequency-based selection
# would likely lower perplexity. Equivalent to the original's slice-to-15005
# followed by overwriting indices 15001..15004 with the specials.
vocab_itos = vocab_itos[:15001] + ["<UNK>", "<BOS>", "<EOS>", "<PAD>"]
print(len(vocab_itos))

vocab_stoi = {token: i for i, token in enumerate(vocab_itos)}

train_ids = _ngram_tensor(_encode(train_set_tok, vocab_stoi, NGRAMS), NGRAMS)

train_test_set_tok = list(tokenize(open('train/train_test_set.tsv', encoding='utf-8').read()))
train_test_set_tok = [line.lower() for line in train_test_set_tok]
train_test_ids = _ngram_tensor(_encode(train_test_set_tok, vocab_stoi, NGRAMS), NGRAMS)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 15,
|
|
"metadata": {},
|
|
"outputs": [],
|
|
"source": [
|
|
# GRU n-gram language model: predicts the next-token distribution from the
# preceding NGRAMS-1 token ids.
class GRU(torch.nn.Module):

    def __init__(self, vocab_size=None):
        """Build embedding -> GRU -> linear classifier.

        vocab_size defaults to len(vocab_itos) so existing `GRU()` call sites
        keep working unchanged.
        """
        super(GRU, self).__init__()
        if vocab_size is None:
            vocab_size = len(vocab_itos)
        self.emb = torch.nn.Embedding(vocab_size, 100)
        self.rec = torch.nn.GRU(100, 256, 1, batch_first=True)
        self.fc1 = torch.nn.Linear(256, vocab_size)
        self.dropout = torch.nn.Dropout(0.5)

    def forward(self, x):
        """x: (batch, seq) long tensor of token ids -> (batch, vocab) logits."""
        emb = self.emb(x)
        output, h_n = self.rec(emb)
        hidden = h_n.squeeze(0)  # final hidden state: (batch, 256)
        # BUG FIX: apply dropout to the hidden representation, not to the output
        # logits — randomly zeroing class scores right before CrossEntropyLoss
        # distorts the training objective.
        hidden = self.dropout(hidden)
        out = self.fc1(hidden)
        return out

lm = GRU().to(device)
optimizer = torch.optim.Adam(lm.parameters(), lr=0.0001)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 20,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "56b8d3f9424a4a6ca15ea27c705ead10",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"epoch: 0\n",
|
|
"train ppl: 429.60890594777385\n",
|
|
"train_test ppl: 354.7605940026038\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "bf252622fa70442aa21dc391275818d3",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"epoch: 1\n",
|
|
"train ppl: 385.04263303807164\n",
|
|
"train_test ppl: 320.5323274780826\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "700fad78591b4cf18ac03e48628c4535",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"epoch: 2\n",
|
|
"train ppl: 388.15715746591627\n",
|
|
"train_test ppl: 331.5143312260392\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "d1b46286cde6423195b0e0321cf4cb37",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"epoch: 3\n",
|
|
"train ppl: 364.4566197255965\n",
|
|
"train_test ppl: 316.9918140368464\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "df3ff22f10cd40bb9758da63481e99e2",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"epoch: 4\n",
|
|
"train ppl: 344.1713452631125\n",
|
|
"train_test ppl: 306.67499426384535\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "20dd67a95f81488dad61194310b0c5b1",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"epoch: 5\n",
|
|
"train ppl: 325.7237671473614\n",
|
|
"train_test ppl: 295.83423173746667\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "306a7f0b7bd340cbafe5ecc784a1738e",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"epoch: 6\n",
|
|
"train ppl: 323.8838574773216\n",
|
|
"train_test ppl: 302.95495879615413\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "972f556564a44554880d446cc0a3b126",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"epoch: 7\n",
|
|
"train ppl: 313.13238735049896\n",
|
|
"train_test ppl: 300.0722307805052\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "740454f9d4544c1bbdd6411a13f9ad75",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"epoch: 8\n",
|
|
"train ppl: 308.2248282795148\n",
|
|
"train_test ppl: 303.25779664571974\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "73a03968179942bebfecc8f35928c016",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"epoch: 9\n",
|
|
"train ppl: 293.68307666273853\n",
|
|
"train_test ppl: 295.00145166486533\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "f6b3bb79ccd84e06909e91a7e6678ee6",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"epoch: 10\n",
|
|
"train ppl: 279.2453691179102\n",
|
|
"train_test ppl: 287.8307587065576\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "a5ba1fd4d2434b18a41955f46e8b4c82",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"epoch: 11\n",
|
|
"train ppl: 267.2034758169644\n",
|
|
"train_test ppl: 282.18074183208086\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "a08ea62337764cd4b72b25b14ea609a2",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"epoch: 12\n",
|
|
"train ppl: 260.65159391269935\n",
|
|
"train_test ppl: 281.92398288442536\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "0f7ebcb5d21a47e78875a829e71fc0c7",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"epoch: 13\n",
|
|
"train ppl: 246.21807765812747\n",
|
|
"train_test ppl: 271.8481103799856\n",
|
|
"\n"
|
|
]
|
|
},
|
|
{
|
|
"data": {
|
|
"application/vnd.jupyter.widget-view+json": {
|
|
"model_id": "29773930d3c246079b26e9e6d4da84fd",
|
|
"version_major": 2,
|
|
"version_minor": 0
|
|
},
|
|
"text/plain": [
|
|
"HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))"
|
|
]
|
|
},
|
|
"metadata": {},
|
|
"output_type": "display_data"
|
|
},
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"\n",
|
|
"epoch: 14\n",
|
|
"train ppl: 234.50125342517168\n",
|
|
"train_test ppl: 265.61149027211843\n",
|
|
"\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
# Training loop: one optimization pass per epoch, then perplexity on both splits.

def _perplexity(data_ids):
    """Perplexity of `lm` over (N, NGRAMS) id windows: exp(mean batch loss).

    Runs in eval mode under no_grad — the original rebuilt the autograd graph
    during evaluation for no reason.
    """
    lm.eval()
    batches = 0
    loss_sum = 0.0
    with torch.no_grad():
        for i in range(0, len(data_ids) - BATCH_SIZE + 1, BATCH_SIZE):
            X = data_ids[i:i + BATCH_SIZE, :NGRAMS - 1]
            Y = data_ids[i:i + BATCH_SIZE, NGRAMS - 1]
            loss_sum += criterion(lm(X), Y).item()
            batches += 1
    return np.exp(loss_sum / batches)

hppl_train = []       # per-epoch train perplexity history
hppl_train_test = []  # per-epoch held-out perplexity history
for epoch in range(EPOCHS):
    lm.train()
    for i in tqdm(range(0, len(train_ids) - BATCH_SIZE + 1, BATCH_SIZE)):
        X = train_ids[i:i + BATCH_SIZE, :NGRAMS - 1]  # context tokens
        Y = train_ids[i:i + BATCH_SIZE, NGRAMS - 1]   # target token
        predictions = lm(X)
        loss = criterion(predictions, Y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # The original duplicated this evaluation code verbatim for each split.
    ppl_train = _perplexity(train_ids)
    ppl_train_test = _perplexity(train_test_ids)

    hppl_train.append(ppl_train)
    hppl_train_test.append(ppl_train_test)
    print('epoch: ', epoch)
    print('train ppl: ', ppl_train)
    print('train_test ppl: ', ppl_train_test)
    print()
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 21,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
"człowiek\n",
|
|
"i\n",
|
|
"nagle\n",
|
|
".—\n",
|
|
"nie\n",
|
|
"będzie\n",
|
|
",\n",
|
|
"nie\n",
|
|
"jestem\n",
|
|
"pewna\n",
|
|
"do\n",
|
|
"niego\n",
|
|
"i\n",
|
|
"nie\n",
|
|
",\n",
|
|
"jak\n",
|
|
"pan\n",
|
|
";\n",
|
|
"jest\n",
|
|
".\n",
|
|
"a\n",
|
|
"jeżeli\n",
|
|
",\n",
|
|
"nawet\n",
|
|
"po\n",
|
|
".\n",
|
|
"na\n",
|
|
"lewo\n",
|
|
"po\n",
|
|
"kilka\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
# 'Gości' and 'Lalka': single-step next-word prediction, then a sampled
# 30-token continuation.

def _to_ids(text):
    """Tokenize, lowercase, and map to vocab ids (<UNK> for OOV words)."""
    toks = [t.lower() for t in tokenize(text)]
    return [vocab_stoi.get(t, vocab_stoi['<UNK>']) for t in toks]

lm.eval()

# Single-step prediction for the 'Gości ...' seed.
ids = torch.tensor(_to_ids('Gości innych nie widział oprócz spółleśników'),
                   dtype=torch.long, device=device)
preds = lm(ids.unsqueeze(0))
# BUG FIX: the predicted word was computed but silently discarded mid-cell;
# display it so the cell actually shows the prediction.
print(vocab_itos[torch.argmax(torch.softmax(preds, 1), 1).item()])

# Free-running generation seeded with 'Lalka': at each step, sample uniformly
# from the top-k candidates, rejecting special tokens.
ids = torch.tensor([_to_ids('Lalka')], dtype=torch.long, device=device)

candidates_number = 10
first_special_id = vocab_stoi['<UNK>']  # specials occupy the top vocab ids
for i in range(30):
    preds = lm(ids)
    candidates = torch.topk(torch.softmax(preds, 1), candidates_number)[1][0].cpu().numpy()
    candidate = first_special_id
    while candidate >= first_special_id:  # resample until a non-special token
        candidate = candidates[np.random.randint(candidates_number)]
    print(vocab_itos[candidate])
    ids = torch.cat((ids, torch.tensor([[candidate]], device=device)), 1)
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 22,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
".\n",
|
|
"o\n",
|
|
"mnie\n",
|
|
"nie\n",
|
|
"było\n",
|
|
".—\n",
|
|
"ani\n",
|
|
".\n",
|
|
"jest\n",
|
|
"jak\n",
|
|
".\n",
|
|
"ale\n",
|
|
"co\n",
|
|
"pan\n",
|
|
"nie\n",
|
|
"obchodzi\n",
|
|
"!\n",
|
|
"nawet\n",
|
|
"nie\n",
|
|
"jest\n",
|
|
"!\n",
|
|
"?\n",
|
|
".\n",
|
|
"i\n",
|
|
"jeszcze\n",
|
|
"do\n",
|
|
".\n",
|
|
"po\n",
|
|
"co\n",
|
|
"do\n",
|
|
"pani\n",
|
|
",\n",
|
|
"który\n",
|
|
",\n",
|
|
"nawet\n",
|
|
",\n",
|
|
"jak\n",
|
|
"ona\n",
|
|
"do\n",
|
|
"panny\n",
|
|
";\n",
|
|
"i\n",
|
|
"nawet\n",
|
|
":\n",
|
|
"o\n",
|
|
"co\n",
|
|
"na\n",
|
|
"myśl\n",
|
|
"!\n",
|
|
".\n",
|
|
"po\n",
|
|
",\n",
|
|
"jak\n",
|
|
"i\n",
|
|
"ja\n",
|
|
"?\n",
|
|
".\n",
|
|
"a\n",
|
|
"jeżeli\n",
|
|
"nie\n",
|
|
"o\n",
|
|
"o\n",
|
|
"?\n",
|
|
"po\n",
|
|
"nie\n",
|
|
"był\n",
|
|
"pani\n",
|
|
".—\n",
|
|
".\n",
|
|
"pan\n",
|
|
"mnie\n",
|
|
"nie\n",
|
|
",\n",
|
|
"nawet\n",
|
|
"mnie\n",
|
|
"o\n",
|
|
".—\n",
|
|
".\n",
|
|
"nie\n",
|
|
"jestem\n",
|
|
",\n",
|
|
"jak\n",
|
|
"on\n",
|
|
",\n",
|
|
"jak\n",
|
|
"nie\n",
|
|
",\n",
|
|
"nawet\n",
|
|
"i\n",
|
|
"nie\n",
|
|
".\n",
|
|
"a\n",
|
|
"jeżeli\n",
|
|
"co\n",
|
|
"?\n",
|
|
"i\n",
|
|
"kto\n",
|
|
"?\n",
|
|
"!\n",
|
|
"na\n",
|
|
"jego\n",
|
|
"ostrzyżonej\n",
|
|
")\n",
|
|
"?\n",
|
|
"do\n",
|
|
"mnie\n",
|
|
"i\n",
|
|
"do\n",
|
|
"na\n",
|
|
"mnie\n",
|
|
"i\n",
|
|
"po\n",
|
|
"co\n",
|
|
"i\n",
|
|
"jeszcze\n",
|
|
":\n",
|
|
"czy\n",
|
|
"nie\n",
|
|
",\n",
|
|
"pani\n",
|
|
"dobrodziejko\n",
|
|
"!\n",
|
|
"na\n",
|
|
"nie\n",
|
|
"i\n",
|
|
"po\n",
|
|
"jego\n",
|
|
"na\n",
|
|
"lewo\n",
|
|
",\n",
|
|
"ale\n",
|
|
",\n",
|
|
"który\n",
|
|
"na\n",
|
|
"niego\n",
|
|
"nie\n",
|
|
"było\n",
|
|
";\n",
|
|
"nie\n",
|
|
"i\n",
|
|
"nie\n",
|
|
"ma\n",
|
|
"na\n",
|
|
",\n",
|
|
"a\n",
|
|
"pani\n",
|
|
"nie\n",
|
|
"mam\n",
|
|
"?\n",
|
|
".\n",
|
|
"nie\n",
|
|
"może\n",
|
|
"na\n",
|
|
"mnie\n",
|
|
"i\n",
|
|
"jeszcze\n",
|
|
"nie\n",
|
|
"mam\n",
|
|
"?\n",
|
|
"ale\n",
|
|
",\n",
|
|
"i\n",
|
|
"już\n",
|
|
",\n",
|
|
"nie\n",
|
|
"mam\n",
|
|
".\n",
|
|
"i\n",
|
|
"cóż\n",
|
|
"!\n",
|
|
")\n",
|
|
".\n",
|
|
"nie\n",
|
|
"jestem\n",
|
|
"o\n",
|
|
"mnie\n",
|
|
"nie\n",
|
|
"i\n",
|
|
"nic\n",
|
|
"?\n",
|
|
"i\n",
|
|
"ja\n",
|
|
".—\n",
|
|
"nie\n",
|
|
"chcę\n",
|
|
",\n",
|
|
"na\n",
|
|
"lewo\n",
|
|
"nie\n",
|
|
"było\n",
|
|
"na\n",
|
|
"jej\n",
|
|
",\n",
|
|
"nie\n",
|
|
"na\n",
|
|
"jej\n",
|
|
"nie\n",
|
|
",\n",
|
|
"ażeby\n",
|
|
"jak\n",
|
|
".\n",
|
|
"ale\n",
|
|
"nie\n",
|
|
"było\n",
|
|
"o\n",
|
|
"nią\n",
|
|
"i\n",
|
|
",\n",
|
|
"a\n",
|
|
"nawet\n",
|
|
"nie\n",
|
|
"jest\n",
|
|
".\n",
|
|
"nie\n",
|
|
"chcę\n",
|
|
".\n",
|
|
"a\n",
|
|
"co\n",
|
|
"pan\n",
|
|
"do\n",
|
|
"niej\n",
|
|
",\n",
|
|
"który\n",
|
|
",\n",
|
|
"na\n",
|
|
"jego\n",
|
|
".\n",
|
|
"była\n",
|
|
"już\n",
|
|
":\n",
|
|
",\n",
|
|
"i\n",
|
|
"nawet\n",
|
|
"go\n",
|
|
"o\n",
|
|
"nim\n",
|
|
";\n",
|
|
"o\n",
|
|
"jej\n",
|
|
"nie\n",
|
|
"było\n",
|
|
"na\n",
|
|
"niego\n",
|
|
"albo\n",
|
|
"i\n",
|
|
".\n",
|
|
"gdy\n",
|
|
"go\n",
|
|
".—\n",
|
|
"co\n",
|
|
"mi\n",
|
|
"do\n",
|
|
"domu\n",
|
|
"?\n",
|
|
"albo\n",
|
|
"i\n",
|
|
",\n",
|
|
"a\n",
|
|
"pan\n",
|
|
",\n",
|
|
"panie\n",
|
|
"nie\n",
|
|
"!\n",
|
|
"!\n",
|
|
"!\n",
|
|
"ja\n",
|
|
"i\n",
|
|
"na\n",
|
|
"jej\n",
|
|
"ochronę\n",
|
|
"do\n",
|
|
",\n",
|
|
"co\n",
|
|
"mnie\n",
|
|
"nie\n",
|
|
"mam\n",
|
|
".—\n",
|
|
"może\n",
|
|
",\n",
|
|
"a\n",
|
|
"nie\n",
|
|
"ma\n",
|
|
"na\n",
|
|
"mnie\n",
|
|
"nie\n",
|
|
",\n",
|
|
"ani\n",
|
|
"i\n",
|
|
"nawet\n",
|
|
"nie\n",
|
|
"na\n",
|
|
"nic\n",
|
|
"!\n",
|
|
".\n",
|
|
"po\n",
|
|
"chwili\n",
|
|
".—\n",
|
|
"nie\n",
|
|
"ma\n",
|
|
"pan\n",
|
|
"ignacy\n",
|
|
".—\n",
|
|
"może\n",
|
|
"mnie\n",
|
|
"nie\n",
|
|
"?\n",
|
|
"nawet\n",
|
|
"?\n",
|
|
"po\n",
|
|
"chwili\n",
|
|
".\n",
|
|
"nie\n",
|
|
"był\n",
|
|
";\n",
|
|
"na\n",
|
|
"myśl\n",
|
|
",\n",
|
|
"a\n",
|
|
"nawet\n",
|
|
"mnie\n",
|
|
"?\n",
|
|
"do\n",
|
|
"na\n",
|
|
"nią\n",
|
|
";\n",
|
|
"i\n",
|
|
"jeszcze\n",
|
|
"jak\n",
|
|
"on\n",
|
|
".\n",
|
|
"i\n",
|
|
"nawet\n",
|
|
"do\n",
|
|
"końca\n",
|
|
"na\n",
|
|
"jego\n",
|
|
"nie\n",
|
|
"i\n",
|
|
"nawet\n",
|
|
"do\n",
|
|
"domu\n",
|
|
"?\n",
|
|
"i\n",
|
|
"o\n",
|
|
"co\n",
|
|
"dzień\n",
|
|
"do\n",
|
|
"pani\n",
|
|
"?\n",
|
|
"a\n",
|
|
",\n",
|
|
"czy\n",
|
|
"nie\n",
|
|
"chcę\n",
|
|
".—\n",
|
|
"ja\n",
|
|
"?\n",
|
|
"i\n",
|
|
"o\n",
|
|
".\n",
|
|
"ja\n",
|
|
",\n",
|
|
"bo\n",
|
|
"nie\n",
|
|
"ma\n",
|
|
"być\n",
|
|
"?\n",
|
|
",\n",
|
|
"nie\n",
|
|
"mam\n",
|
|
"na\n",
|
|
"co\n",
|
|
".—\n",
|
|
",\n",
|
|
"ja\n",
|
|
"?\n",
|
|
",\n",
|
|
"co\n",
|
|
"?\n",
|
|
")\n",
|
|
"do\n",
|
|
"pana\n",
|
|
".\n",
|
|
"na\n",
|
|
"lewo\n",
|
|
".\n",
|
|
"nie\n",
|
|
"na\n",
|
|
"nic\n",
|
|
".\n",
|
|
"ale\n",
|
|
"nie\n",
|
|
",\n",
|
|
"a\n",
|
|
"ja\n",
|
|
"?\n",
|
|
",\n",
|
|
"a\n",
|
|
"co\n",
|
|
"do\n",
|
|
"pani\n",
|
|
".\n",
|
|
"była\n",
|
|
"do\n",
|
|
"pani\n",
|
|
"meliton\n",
|
|
":\n",
|
|
"albo\n",
|
|
"o\n",
|
|
",\n",
|
|
"ażeby\n",
|
|
",\n",
|
|
"ale\n",
|
|
"co\n",
|
|
",\n",
|
|
"jak\n",
|
|
"ona\n",
|
|
"na\n",
|
|
"niego\n",
|
|
";\n",
|
|
".\n",
|
|
"ale\n",
|
|
"jeszcze\n",
|
|
"na\n",
|
|
",\n",
|
|
"na\n",
|
|
"jego\n",
|
|
"miejscu\n",
|
|
"i\n",
|
|
"była\n",
|
|
".—\n",
|
|
"i\n",
|
|
"ja\n",
|
|
".—\n",
|
|
"na\n",
|
|
"nią\n",
|
|
"nie\n",
|
|
"było\n",
|
|
".—\n",
|
|
"co\n",
|
|
"do\n",
|
|
"mnie\n",
|
|
",\n",
|
|
"ale\n",
|
|
"nawet\n",
|
|
",\n",
|
|
"do\n",
|
|
"licha\n",
|
|
"na\n",
|
|
"myśl\n",
|
|
"i\n",
|
|
"do\n",
|
|
".—\n",
|
|
"o\n",
|
|
"mnie\n",
|
|
"pan\n",
|
|
"na\n",
|
|
"co\n",
|
|
"dzień\n",
|
|
"na\n",
|
|
"głowie\n",
|
|
".—\n",
|
|
"co\n",
|
|
".\n",
|
|
"nie\n",
|
|
"jest\n",
|
|
"ci\n",
|
|
".—\n",
|
|
"pan\n",
|
|
".\n",
|
|
"nie\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
# dev-0 predictions: emit one sampled next word per input line, continuing the
# generation context `ids` from the previous cell.
# NOTE(review): `ids` grows by one token per line and is never truncated, so
# each prediction conditions on the entire generated history — confirm this is
# intended for the gonito-style per-line prediction task.
with open("dev-0/in.tsv", "r", encoding='utf-8') as dev_path:
    nr_of_dev_lines = len(dev_path.readlines())

with open("dev-0/out.tsv", "w", encoding='utf-8') as out_dev_file:
    for i in range(nr_of_dev_lines):
        preds = lm(ids)
        probs = torch.softmax(preds, 1)
        candidates = torch.topk(probs, candidates_number)[1][0].cpu().numpy()
        # Sample from the top-k, rejecting special-token ids (> 15000).
        candidate = 15001
        while candidate > 15000:
            candidate = candidates[np.random.randint(candidates_number)]
        word = vocab_itos[candidate]
        print(word)
        ids = torch.cat((ids, torch.tensor([[candidate]], device=device)), 1)
        out_dev_file.write(word + '\n')
|
|
]
|
|
},
|
|
{
|
|
"cell_type": "code",
|
|
"execution_count": 24,
|
|
"metadata": {},
|
|
"outputs": [
|
|
{
|
|
"name": "stdout",
|
|
"output_type": "stream",
|
|
"text": [
|
|
",\n",
|
|
"a\n",
|
|
"ja\n",
|
|
".\n",
|
|
"na\n",
|
|
"co\n",
|
|
"na\n",
|
|
"mnie\n",
|
|
"i\n",
|
|
"kto\n",
|
|
",\n",
|
|
"ale\n",
|
|
"nawet\n",
|
|
"na\n",
|
|
"mnie\n",
|
|
"!\n",
|
|
"co\n",
|
|
"ja\n",
|
|
".—\n",
|
|
"już\n",
|
|
"?\n",
|
|
"!\n",
|
|
")\n",
|
|
"i\n",
|
|
"pan\n",
|
|
"na\n",
|
|
"myśl\n",
|
|
";\n",
|
|
",\n",
|
|
"a\n",
|
|
"nawet\n",
|
|
"nie\n",
|
|
",\n",
|
|
"jak\n",
|
|
"pan\n",
|
|
"na\n",
|
|
"mnie\n",
|
|
"na\n",
|
|
",\n",
|
|
"i\n",
|
|
"o\n",
|
|
"co\n",
|
|
"ja\n",
|
|
"nie\n",
|
|
"chcę\n",
|
|
".—\n",
|
|
",\n",
|
|
"nie\n",
|
|
"mam\n",
|
|
"?\n",
|
|
"?\n",
|
|
"nie\n",
|
|
".\n",
|
|
"pani\n",
|
|
"nie\n",
|
|
"jest\n",
|
|
"na\n",
|
|
"co\n",
|
|
"nie\n",
|
|
"może\n",
|
|
"i\n",
|
|
"cóż\n",
|
|
"nie\n",
|
|
".\n",
|
|
"a\n",
|
|
"jeżeli\n",
|
|
"jak\n",
|
|
"ona\n",
|
|
"!\n",
|
|
"na\n",
|
|
"dole\n",
|
|
".\n",
|
|
"nie\n",
|
|
"był\n",
|
|
"pan\n",
|
|
".\n",
|
|
"nie\n",
|
|
"jest\n",
|
|
"jeszcze\n",
|
|
"jak\n",
|
|
"pani\n",
|
|
"?\n",
|
|
"i\n",
|
|
"o\n",
|
|
"?\n",
|
|
"po\n",
|
|
"?\n",
|
|
"po\n",
|
|
"co\n",
|
|
"dzień\n",
|
|
"?\n",
|
|
"na\n",
|
|
",\n",
|
|
"co\n",
|
|
"pan\n",
|
|
"do\n",
|
|
"niego\n",
|
|
"na\n",
|
|
"głowie\n",
|
|
".—\n",
|
|
".\n",
|
|
"nie\n",
|
|
"był\n",
|
|
".\n",
|
|
"na\n",
|
|
"myśl\n",
|
|
";\n",
|
|
"i\n",
|
|
"ja\n",
|
|
".\n",
|
|
"na\n",
|
|
"lewo\n",
|
|
";\n",
|
|
"była\n",
|
|
"go\n",
|
|
",\n",
|
|
"na\n",
|
|
"jej\n",
|
|
".—\n",
|
|
"o\n",
|
|
"!\n",
|
|
"?\n",
|
|
"na\n",
|
|
"co\n",
|
|
"!\n",
|
|
")\n",
|
|
"do\n",
|
|
"głowy\n",
|
|
".\n",
|
|
"i\n",
|
|
"nawet\n",
|
|
"do\n",
|
|
"niej\n",
|
|
".\n",
|
|
"nie\n",
|
|
"był\n",
|
|
";\n",
|
|
"o\n",
|
|
"ile\n",
|
|
"o\n",
|
|
"jego\n",
|
|
"o\n",
|
|
";\n",
|
|
"ale\n",
|
|
"pan\n",
|
|
"ignacy\n",
|
|
".—\n",
|
|
"nie\n",
|
|
"ma\n",
|
|
"pan\n",
|
|
"do\n",
|
|
".\n",
|
|
"ja\n",
|
|
"do\n",
|
|
"mego\n",
|
|
"i\n",
|
|
"nie\n",
|
|
"będzie\n",
|
|
"o\n",
|
|
"mnie\n",
|
|
"i\n",
|
|
"już\n",
|
|
".\n",
|
|
"o\n",
|
|
"co\n",
|
|
"pan\n",
|
|
"ignacy\n",
|
|
"?\n",
|
|
"na\n",
|
|
"którym\n",
|
|
",\n",
|
|
"kiedy\n",
|
|
"go\n",
|
|
"na\n",
|
|
"jej\n",
|
|
";\n",
|
|
"ale\n",
|
|
"co\n",
|
|
",\n",
|
|
"a\n",
|
|
"co\n",
|
|
"pan\n",
|
|
"?\n",
|
|
"i\n",
|
|
"kto\n",
|
|
"mu\n",
|
|
"pan\n",
|
|
",\n",
|
|
"co\n",
|
|
"?\n",
|
|
"o\n",
|
|
",\n",
|
|
"i\n",
|
|
"kto\n",
|
|
"by\n",
|
|
"mnie\n",
|
|
"do\n",
|
|
"głowy\n",
|
|
".—\n",
|
|
"a\n",
|
|
"!\n",
|
|
"nawet\n",
|
|
"o\n",
|
|
"niej\n",
|
|
"na\n",
|
|
"myśl\n",
|
|
"?\n",
|
|
"i\n",
|
|
"już\n",
|
|
"do\n",
|
|
".\n",
|
|
"nie\n",
|
|
"na\n",
|
|
"mnie\n",
|
|
"nie\n",
|
|
"mam\n",
|
|
".\n",
|
|
"była\n",
|
|
"już\n",
|
|
".\n",
|
|
"(\n",
|
|
",\n",
|
|
"nie\n",
|
|
"!\n",
|
|
",\n",
|
|
"jak\n",
|
|
"on\n",
|
|
"mnie\n",
|
|
".—\n",
|
|
"pan\n",
|
|
".\n",
|
|
"(\n",
|
|
"może\n",
|
|
"na\n",
|
|
"nie\n",
|
|
"było\n",
|
|
"i\n",
|
|
",\n",
|
|
"który\n",
|
|
"by\n",
|
|
"mu\n",
|
|
"nie\n",
|
|
".\n",
|
|
"i\n",
|
|
"dopiero\n",
|
|
".\n",
|
|
"a\n",
|
|
",\n",
|
|
"jak\n",
|
|
"ja\n",
|
|
",\n",
|
|
"na\n",
|
|
"którym\n",
|
|
"?\n",
|
|
"a\n",
|
|
"jeżeli\n",
|
|
"jest\n",
|
|
"bardzo\n",
|
|
"?\n",
|
|
"!\n",
|
|
",\n",
|
|
"bo\n",
|
|
"już\n",
|
|
".\n",
|
|
"nie\n",
|
|
"chcę\n",
|
|
"go\n",
|
|
"do\n",
|
|
"paryża\n",
|
|
".—\n",
|
|
"co\n",
|
|
"dzień\n",
|
|
"pan\n",
|
|
"nie\n",
|
|
".\n",
|
|
"?\n",
|
|
"co\n",
|
|
"na\n",
|
|
"myśl\n",
|
|
"!\n",
|
|
",\n",
|
|
"a\n",
|
|
"może\n",
|
|
"jeszcze\n",
|
|
"na\n",
|
|
"niego\n",
|
|
",\n",
|
|
"nie\n",
|
|
"ma\n",
|
|
",\n",
|
|
"a\n",
|
|
"pan\n",
|
|
"nie\n",
|
|
"będzie\n",
|
|
".—\n",
|
|
"nic\n",
|
|
"mnie\n",
|
|
"pan\n",
|
|
".\n",
|
|
"*\n",
|
|
".\n",
|
|
"ja\n",
|
|
"nie\n",
|
|
",\n",
|
|
"pani\n",
|
|
"dobrodziejko\n",
|
|
".—\n",
|
|
"i\n",
|
|
"cóż\n",
|
|
".\n",
|
|
"pan\n",
|
|
"nie\n",
|
|
"jadł\n",
|
|
"na\n",
|
|
"nich\n",
|
|
"!\n",
|
|
";\n",
|
|
"na\n",
|
|
"lewo\n",
|
|
"na\n",
|
|
"mnie\n",
|
|
"i\n",
|
|
"na\n",
|
|
"nogi\n",
|
|
"?\n",
|
|
".—\n",
|
|
"nie\n",
|
|
"chcę\n",
|
|
"?\n",
|
|
",\n",
|
|
"co\n",
|
|
"by\n",
|
|
"?\n",
|
|
"!\n",
|
|
"o\n",
|
|
"?\n",
|
|
"po\n",
|
|
"i\n",
|
|
"nawet\n",
|
|
",\n",
|
|
"jak\n",
|
|
"ja\n",
|
|
".\n",
|
|
"ale\n",
|
|
"o\n",
|
|
"jej\n",
|
|
"!\n",
|
|
",\n",
|
|
"jak\n",
|
|
"ja\n",
|
|
"już\n",
|
|
"nic\n",
|
|
"!\n",
|
|
")\n",
|
|
"!\n",
|
|
"cha\n",
|
|
",\n",
|
|
"ale\n",
|
|
"nawet\n",
|
|
"do\n",
|
|
"głowy\n",
|
|
"na\n",
|
|
",\n",
|
|
"nie\n",
|
|
"mógł\n",
|
|
"nawet\n",
|
|
"nie\n",
|
|
"mógł\n",
|
|
"do\n",
|
|
"niego\n",
|
|
"nie\n",
|
|
"na\n",
|
|
"mnie\n",
|
|
"?\n",
|
|
")\n",
|
|
",\n",
|
|
"ale\n",
|
|
"jeszcze\n",
|
|
".\n",
|
|
"po\n",
|
|
".\n",
|
|
"o\n",
|
|
"mnie\n",
|
|
"na\n",
|
|
"jego\n",
|
|
"na\n",
|
|
"myśl\n",
|
|
"i\n",
|
|
"nawet\n",
|
|
"na\n",
|
|
"lewo\n",
|
|
"na\n",
|
|
"głowie\n",
|
|
"na\n",
|
|
"górę\n",
|
|
"i\n",
|
|
"po\n",
|
|
"otworzeniu\n",
|
|
";\n",
|
|
"ale\n",
|
|
"co\n",
|
|
"do\n",
|
|
"na\n",
|
|
"jego\n",
|
|
".—\n",
|
|
"a\n",
|
|
"pan\n",
|
|
"i\n",
|
|
"co\n",
|
|
".\n",
|
|
"jest\n",
|
|
"pan\n",
|
|
"ignacy\n",
|
|
"do\n",
|
|
"paryża\n",
|
|
"nie\n",
|
|
"mam\n",
|
|
".\n",
|
|
"a\n",
|
|
"jeżeli\n",
|
|
"na\n",
|
|
"jej\n",
|
|
"?\n",
|
|
".\n",
|
|
"o\n",
|
|
"nie\n",
|
|
"i\n",
|
|
"nie\n",
|
|
".\n",
|
|
"o\n",
|
|
"jego\n",
|
|
"po\n",
|
|
"pokoju\n",
|
|
",\n",
|
|
"jak\n",
|
|
"ja\n",
|
|
"już\n",
|
|
":\n",
|
|
"od\n",
|
|
"na\n",
|
|
"do\n",
|
|
";\n",
|
|
"ale\n",
|
|
"nawet\n",
|
|
"o\n",
|
|
"niej\n",
|
|
"nie\n",
|
|
"jest\n",
|
|
",\n",
|
|
"ale\n",
|
|
",\n",
|
|
"jak\n",
|
|
",\n",
|
|
"na\n",
|
|
"jej\n",
|
|
".\n",
|
|
"nie\n",
|
|
"był\n",
|
|
"ani\n",
|
|
"ani\n",
|
|
"do\n",
|
|
",\n",
|
|
"a\n",
|
|
"na\n",
|
|
"nią\n",
|
|
":\n",
|
|
"nawet\n",
|
|
"co\n",
|
|
"nie\n",
|
|
".\n",
|
|
"na\n"
|
|
]
|
|
}
|
|
],
|
|
"source": [
|
|
# test-A predictions: emit one sampled next word per input line.
with open("test-A/in.tsv", "r", encoding='utf-8') as test_a_path:
    nr_of_test_a_lines = len(test_a_path.readlines())

with open("test-A/out.tsv", "w", encoding='utf-8') as out_test_file:
    # BUG FIX: iterate over the test-A line count. The original counted
    # test-A's lines into nr_of_test_a_lines but then looped over
    # nr_of_dev_lines (copy-paste from the dev-0 cell), writing the wrong
    # number of predictions whenever the two files differ in length.
    for i in range(nr_of_test_a_lines):
        preds = lm(ids)
        candidates = torch.topk(torch.softmax(preds, 1), candidates_number)[1][0].cpu().numpy()
        # Sample from the top-k, rejecting special-token ids (> 15000).
        candidate = 15001
        while candidate > 15000:
            candidate = candidates[np.random.randint(candidates_number)]
        print(vocab_itos[candidate])
        ids = torch.cat((ids, torch.tensor([[candidate]], device=device)), 1)
        out_test_file.write(vocab_itos[candidate] + '\n')
|
|
]
|
|
}
|
|
],
|
|
"metadata": {
|
|
"kernelspec": {
|
|
"display_name": "Python 3",
|
|
"language": "python",
|
|
"name": "python3"
|
|
},
|
|
"language_info": {
|
|
"codemirror_mode": {
|
|
"name": "ipython",
|
|
"version": 3
|
|
},
|
|
"file_extension": ".py",
|
|
"mimetype": "text/x-python",
|
|
"name": "python",
|
|
"nbconvert_exporter": "python",
|
|
"pygments_lexer": "ipython3",
|
|
"version": "3.8.3"
|
|
}
|
|
},
|
|
"nbformat": 4,
|
|
"nbformat_minor": 4
|
|
}
|