challenging-america-word-ga.../run.ipynb
2022-05-28 15:36:41 +02:00

810 lines
42 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "03de852a",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import regex as re\n",
"import csv\n",
"import torch\n",
"from torch import nn\n",
"from gensim.models import Word2Vec\n",
"from nltk.tokenize import word_tokenize"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "73497953",
"metadata": {},
"outputs": [],
"source": [
"# Seed NumPy and PyTorch so weight initialisation and dropout are reproducible.\n",
"SEED = 42\n",
"np.random.seed(SEED)\n",
"torch.manual_seed(SEED)\n",
"\n",
"# Release GPU memory cached by an earlier run of this kernel, then pick the device.\n",
"torch.cuda.empty_cache()\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4227ef55",
"metadata": {},
"outputs": [],
"source": [
"def clean_text(text):\n",
"    \"\"\"Normalise a raw text fragment before tokenisation.\n",
"\n",
"    Lower-cases the text, re-joins words hyphenated across a line break,\n",
"    turns the remaining line-break markers into spaces, expands a few English\n",
"    contractions and finally strips every Unicode punctuation character.\n",
"    \"\"\"\n",
"    text = text.lower()\n",
"    # Line breaks are expected as an escaped marker (one or more backslashes\n",
"    # followed by 'n') rather than real newlines; accept any number of\n",
"    # backslashes. NOTE(review): confirm the marker format against in.tsv.\n",
"    text = re.sub(r'-\\\\+n', '', text)  # word hyphenated across a line break\n",
"    text = re.sub(r'\\\\+n', ' ', text)\n",
"    # Expand contractions BEFORE stripping punctuation: the apostrophe is itself\n",
"    # punctuation (\\p{P}), so afterwards these patterns could never match.\n",
"    # \"n't\" rather than \"'t\" so that e.g. \"isn't\" becomes \"is not\", not \"isn not\".\n",
"    # Note: \"'s\" is expanded to \"is\" even when it marks a possessive.\n",
"    text = text.replace(\"n't\", \" not\").replace(\"'s\", \" is\").replace(\"'ll\", \" will\").replace(\"'m\", \" am\").replace(\"'ve\", \" have\")\n",
"    text = re.sub(r'\\p{P}', '', text)\n",
"\n",
"    return text"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "758cf94a",
"metadata": {},
"outputs": [],
"source": [
"train_data = pd.read_csv('train/in.tsv.xz', sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)\n",
"train_labels = pd.read_csv('train/expected.tsv', sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)\n",
"\n",
"# Inputs and labels are paired purely by row position, so a line skipped in only one\n",
"# of the files (on_bad_lines='skip') would silently shift every later label.\n",
"assert len(train_data) == len(train_labels), f'{len(train_data)} input rows vs {len(train_labels)} labels'\n",
"\n",
"# Keep the left (6) and right (7) contexts; the gap word is column 0 of the labels.\n",
"train_data = train_data[[6, 7]]\n",
"train_data = pd.concat([train_data, train_labels], axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "384922c7",
"metadata": {},
"outputs": [],
"source": [
"class TrainCorpus:\n",
"    \"\"\"Re-iterable corpus of tokenised training texts for gensim's Word2Vec.\n",
"\n",
"    Each row's left context (column 6), gap word (column 0) and right context\n",
"    (column 7) are joined into one text, cleaned with ``clean_text`` and\n",
"    tokenised with ``word_tokenize``. Being a class with ``__iter__`` (not a\n",
"    one-shot generator), the corpus can be iterated more than once.\n",
"    \"\"\"\n",
"    def __init__(self, data):\n",
"        self.data = data\n",
"\n",
"    def __iter__(self):\n",
"        for left, gap, right in self.data[[6, 0, 7]].itertuples(index=False):\n",
"            # Join with spaces so the last word of one part and the first word of\n",
"            # the next are not fused into a single token.\n",
"            text = ' '.join((str(left), str(gap), str(right)))\n",
"            yield word_tokenize(clean_text(text))"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "3da5f19b",
"metadata": {},
"outputs": [],
"source": [
"# The vocabulary is built from the first 100,000 training rows. Words seen fewer\n",
"# than min_count=10 times are left out (and are later mapped to '<unk>').\n",
"# NOTE(review): only build_vocab() is called on this model below and the GRU learns\n",
"# its own nn.Embedding, so the Word2Vec vectors themselves look unused -- confirm.\n",
"w2v_model = Word2Vec(vector_size=100, min_count=10)\n",
"train_sentences = TrainCorpus(train_data.head(100000))"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "183d43be",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"97122\n"
]
}
],
"source": [
"w2v_model.build_vocab(corpus_iterable=train_sentences)\n",
"\n",
"# Work on copies: appending '<unk>' to gensim's own lists would leave w2v_model.wv\n",
"# with one more key than it has vectors.\n",
"index_to_key = list(w2v_model.wv.index_to_key)\n",
"key_to_index = dict(w2v_model.wv.key_to_index)\n",
"\n",
"# Shared index for every out-of-vocabulary token (e.g. words below min_count).\n",
"if '<unk>' not in key_to_index:\n",
"    index_to_key.append('<unk>')\n",
"    key_to_index['<unk>'] = len(index_to_key) - 1\n",
"\n",
"vocab_size = len(index_to_key)\n",
"print(vocab_size)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e63dd9fe",
"metadata": {},
"outputs": [],
"source": [
"class TrainDataset(torch.utils.data.IterableDataset):\n",
"    \"\"\"Stream (input, target) windows of token indices for next-word prediction.\n",
"\n",
"    Each row's left context (column 6), gap word (column 0) and right context\n",
"    (column 7) are joined into one text, cleaned and tokenised. A window of\n",
"    ``context_size`` tokens then slides over it one token at a time; the target\n",
"    is the input window shifted right by one, so target[t] is the token that\n",
"    follows input[t]. Tokens missing from ``key_to_index`` map to '<unk>'.\n",
"\n",
"    This dataset does not shard its stream by DataLoader worker, so every worker\n",
"    would replay all of it; use it with the default ``num_workers=0``.\n",
"    \"\"\"\n",
"    def __init__(self, data, index_to_key, key_to_index, context_size=5):\n",
"        self.data = data\n",
"        self.index_to_key = index_to_key\n",
"        self.key_to_index = key_to_index\n",
"        self.vocab_size = len(key_to_index)\n",
"        self.context_size = context_size\n",
"\n",
"    def __iter__(self):\n",
"        unk = self.key_to_index['<unk>']\n",
"        n = self.context_size\n",
"        for left, gap, right in self.data[[6, 0, 7]].itertuples(index=False):\n",
"            # Join with spaces so the boundary words of the three parts are not\n",
"            # fused into a single token.\n",
"            text = clean_text(' '.join((str(left), str(gap), str(right))))\n",
"            ids = [self.key_to_index.get(word, unk) for word in word_tokenize(text)]\n",
"            for i in range(n, len(ids)):\n",
"                yield (np.asarray(ids[i - n:i], dtype=np.int64),\n",
"                       np.asarray(ids[i - n + 1:i + 1], dtype=np.int64))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "7c60ddc1",
"metadata": {},
"outputs": [],
"source": [
"class Model(nn.Module):\n",
"    \"\"\"Embedding -> 2-layer GRU -> linear projection onto the vocabulary.\n",
"\n",
"    ``forward`` takes token indices of shape (batch, seq) -- the layout the\n",
"    DataLoader builds from TrainDataset's windows -- and returns raw logits of\n",
"    shape (batch, seq, vocab_size) together with the final GRU hidden state.\n",
"    \"\"\"\n",
"    def __init__(self, embed_size, vocab_size):\n",
"        super(Model, self).__init__()\n",
"        self.embed_size = embed_size\n",
"        self.vocab_size = vocab_size\n",
"        self.gru_size = 128\n",
"        self.num_layers = 2\n",
"\n",
"        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=self.embed_size)\n",
"        # batch_first=True because inputs are (batch, seq); without it the GRU would\n",
"        # treat the batch axis as time and carry state across unrelated windows.\n",
"        self.gru = nn.GRU(input_size=self.embed_size, hidden_size=self.gru_size,\n",
"                          num_layers=self.num_layers, dropout=0.2, batch_first=True)\n",
"        self.fc = nn.Linear(self.gru_size, vocab_size)\n",
"\n",
"    def forward(self, x, prev_state = None):\n",
"        embed = self.embed(x)\n",
"        output, state = self.gru(embed, prev_state)\n",
"        # Raw logits: nn.CrossEntropyLoss applies log-softmax itself. For probabilities\n",
"        # use torch.softmax(logits, dim=-1) -- the vocabulary is the last axis.\n",
"        logits = self.fc(output)\n",
"        return logits, state\n",
"\n",
"    def init_state(self, sequence_length):\n",
"        \"\"\"Return an all-zero initial hidden state to pass as ``prev_state``.\n",
"\n",
"        Despite its name, ``sequence_length`` is the batch size: a batched GRU hidden\n",
"        state has shape (num_layers, batch, hidden_size). A GRU keeps a single hidden\n",
"        tensor (an LSTM would need an (h, c) pair), so one tensor is returned.\n",
"        \"\"\"\n",
"        return torch.zeros(self.num_layers, sequence_length, self.gru_size, device=device)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "1c7b8fab",
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"from torch.optim import Adam\n",
"\n",
"def train(dataset, model, max_epochs, batch_size, lr=0.001, log_every=1000):\n",
"    \"\"\"Train ``model`` on ``dataset`` with Adam and token-level cross-entropy.\n",
"\n",
"    Args:\n",
"        dataset: yields (input, target) index arrays of equal length.\n",
"        model: maps a (batch, seq) index tensor to ((batch, seq, vocab) logits, state).\n",
"        max_epochs: number of passes over ``dataset``.\n",
"        batch_size: windows per optimisation step.\n",
"        lr: Adam learning rate.\n",
"        log_every: print the current batch loss every this many batches.\n",
"    \"\"\"\n",
"    model.train()\n",
"\n",
"    dataloader = DataLoader(dataset, batch_size=batch_size)\n",
"    criterion = nn.CrossEntropyLoss()\n",
"    optimizer = Adam(model.parameters(), lr=lr)\n",
"\n",
"    for epoch in range(max_epochs):\n",
"        for batch, (x, y) in enumerate(dataloader):\n",
"            optimizer.zero_grad()\n",
"\n",
"            x = x.to(device)\n",
"            y = y.to(device)\n",
"\n",
"            y_pred, _ = model(x)\n",
"            # CrossEntropyLoss expects the class axis second: (batch, vocab, seq).\n",
"            loss = criterion(y_pred.transpose(1, 2), y)\n",
"\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"\n",
"            if batch % log_every == 0:\n",
"                # The total is unknown ('???') because an IterableDataset has no length.\n",
"                print(f'epoch: {epoch}, update in batch {batch}/???, loss: {loss.item()}')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "3531d21d",
"metadata": {},
"outputs": [],
"source": [
"# Stream training windows from the same 100,000 rows used to build the vocabulary.\n",
"train_dataset = TrainDataset(\n",
"    train_data.head(100000),\n",
"    index_to_key=index_to_key,\n",
"    key_to_index=key_to_index,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "f72f1f6d",
"metadata": {},
"outputs": [],
"source": [
"# 100-dimensional embeddings learned from scratch, over the vocabulary built above.\n",
"model = Model(embed_size=100, vocab_size=vocab_size)\n",
"model = model.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "d608d9fe",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, update in batch 0/???, loss: 11.47425365447998\n",
"epoch: 0, update in batch 1000/???, loss: 7.042205810546875\n",
"epoch: 0, update in batch 2000/???, loss: 7.235440731048584\n",
"epoch: 0, update in batch 3000/???, loss: 7.251269340515137\n",
"epoch: 0, update in batch 4000/???, loss: 6.944191932678223\n",
"epoch: 0, update in batch 5000/???, loss: 6.263372421264648\n",
"epoch: 0, update in batch 6000/???, loss: 6.181947231292725\n",
"epoch: 0, update in batch 7000/???, loss: 6.508013725280762\n",
"epoch: 0, update in batch 8000/???, loss: 6.658236026763916\n",
"epoch: 0, update in batch 9000/???, loss: 6.536279201507568\n",
"epoch: 0, update in batch 10000/???, loss: 6.802626609802246\n",
"epoch: 0, update in batch 11000/???, loss: 6.961029052734375\n",
"epoch: 0, update in batch 12000/???, loss: 7.713824272155762\n",
"epoch: 0, update in batch 13000/???, loss: 8.100411415100098\n",
"epoch: 0, update in batch 14000/???, loss: 6.457145690917969\n",
"epoch: 0, update in batch 15000/???, loss: 6.850286960601807\n",
"epoch: 0, update in batch 16000/???, loss: 6.794063568115234\n",
"epoch: 0, update in batch 17000/???, loss: 6.311314582824707\n",
"epoch: 0, update in batch 18000/???, loss: 6.611917018890381\n",
"epoch: 0, update in batch 19000/???, loss: 5.679810523986816\n",
"epoch: 0, update in batch 20000/???, loss: 7.110655307769775\n",
"epoch: 0, update in batch 21000/???, loss: 7.170722961425781\n",
"epoch: 0, update in batch 22000/???, loss: 6.584908485412598\n",
"epoch: 0, update in batch 23000/???, loss: 7.224095344543457\n",
"epoch: 0, update in batch 24000/???, loss: 5.827445983886719\n",
"epoch: 0, update in batch 25000/???, loss: 6.444586753845215\n",
"epoch: 0, update in batch 26000/???, loss: 6.149054527282715\n",
"epoch: 0, update in batch 27000/???, loss: 6.259871482849121\n",
"epoch: 0, update in batch 28000/???, loss: 5.789839744567871\n",
"epoch: 0, update in batch 29000/???, loss: 7.025563716888428\n",
"epoch: 0, update in batch 30000/???, loss: 7.265492916107178\n",
"epoch: 0, update in batch 31000/???, loss: 4.921586036682129\n",
"epoch: 0, update in batch 32000/???, loss: 6.467754364013672\n",
"epoch: 0, update in batch 33000/???, loss: 7.393715858459473\n",
"epoch: 0, update in batch 34000/???, loss: 6.9696760177612305\n",
"epoch: 0, update in batch 35000/???, loss: 7.276318550109863\n",
"epoch: 0, update in batch 36000/???, loss: 7.011231899261475\n",
"epoch: 0, update in batch 37000/???, loss: 7.029260158538818\n",
"epoch: 0, update in batch 38000/???, loss: 6.723126411437988\n",
"epoch: 0, update in batch 39000/???, loss: 6.828773498535156\n",
"epoch: 0, update in batch 40000/???, loss: 6.069770336151123\n",
"epoch: 0, update in batch 41000/???, loss: 6.651298522949219\n",
"epoch: 0, update in batch 42000/???, loss: 7.455380916595459\n",
"epoch: 0, update in batch 43000/???, loss: 5.594773769378662\n",
"epoch: 0, update in batch 44000/???, loss: 6.102865219116211\n",
"epoch: 0, update in batch 45000/???, loss: 6.04202127456665\n",
"epoch: 0, update in batch 46000/???, loss: 6.472177982330322\n",
"epoch: 0, update in batch 47000/???, loss: 5.870923042297363\n",
"epoch: 0, update in batch 48000/???, loss: 6.286317348480225\n",
"epoch: 0, update in batch 49000/???, loss: 7.157052516937256\n",
"epoch: 0, update in batch 50000/???, loss: 5.888463020324707\n",
"epoch: 0, update in batch 51000/???, loss: 5.609915733337402\n",
"epoch: 0, update in batch 52000/???, loss: 6.565190315246582\n",
"epoch: 0, update in batch 53000/???, loss: 6.4924468994140625\n",
"epoch: 0, update in batch 54000/???, loss: 6.856420040130615\n",
"epoch: 0, update in batch 55000/???, loss: 7.389428615570068\n",
"epoch: 0, update in batch 56000/???, loss: 5.927685260772705\n",
"epoch: 0, update in batch 57000/???, loss: 7.4227423667907715\n",
"epoch: 0, update in batch 58000/???, loss: 6.46466064453125\n",
"epoch: 0, update in batch 59000/???, loss: 6.586294651031494\n",
"epoch: 0, update in batch 60000/???, loss: 5.797083854675293\n",
"epoch: 0, update in batch 61000/???, loss: 4.825878143310547\n",
"epoch: 0, update in batch 62000/???, loss: 6.911933898925781\n",
"epoch: 0, update in batch 63000/???, loss: 7.684759616851807\n",
"epoch: 0, update in batch 64000/???, loss: 5.716580390930176\n",
"epoch: 0, update in batch 65000/???, loss: 6.1738667488098145\n",
"epoch: 0, update in batch 66000/???, loss: 6.219714164733887\n",
"epoch: 0, update in batch 67000/???, loss: 5.4024128913879395\n",
"epoch: 0, update in batch 68000/???, loss: 6.912312984466553\n",
"epoch: 0, update in batch 69000/???, loss: 6.703289031982422\n",
"epoch: 0, update in batch 70000/???, loss: 7.375630855560303\n",
"epoch: 0, update in batch 71000/???, loss: 5.757082462310791\n",
"epoch: 0, update in batch 72000/???, loss: 5.992405414581299\n",
"epoch: 0, update in batch 73000/???, loss: 6.706838130950928\n",
"epoch: 0, update in batch 74000/???, loss: 7.376870155334473\n",
"epoch: 0, update in batch 75000/???, loss: 6.676860809326172\n",
"epoch: 0, update in batch 76000/???, loss: 5.904101848602295\n",
"epoch: 0, update in batch 77000/???, loss: 6.776932716369629\n",
"epoch: 0, update in batch 78000/???, loss: 5.682181358337402\n",
"epoch: 0, update in batch 79000/???, loss: 6.211178302764893\n",
"epoch: 0, update in batch 80000/???, loss: 6.366950035095215\n",
"epoch: 0, update in batch 81000/???, loss: 5.25206184387207\n",
"epoch: 0, update in batch 82000/???, loss: 6.30997371673584\n",
"epoch: 0, update in batch 83000/???, loss: 6.351908206939697\n",
"epoch: 0, update in batch 84000/???, loss: 7.659114837646484\n",
"epoch: 0, update in batch 85000/???, loss: 6.5041704177856445\n",
"epoch: 0, update in batch 86000/???, loss: 6.770291328430176\n",
"epoch: 0, update in batch 87000/???, loss: 6.530011177062988\n",
"epoch: 0, update in batch 88000/???, loss: 6.317249298095703\n",
"epoch: 0, update in batch 89000/???, loss: 6.191559314727783\n",
"epoch: 0, update in batch 90000/???, loss: 5.79150390625\n",
"epoch: 0, update in batch 91000/???, loss: 6.356796741485596\n",
"epoch: 0, update in batch 92000/???, loss: 7.3577141761779785\n",
"epoch: 0, update in batch 93000/???, loss: 6.529308319091797\n",
"epoch: 0, update in batch 94000/???, loss: 7.740485191345215\n",
"epoch: 0, update in batch 95000/???, loss: 6.348109245300293\n",
"epoch: 0, update in batch 96000/???, loss: 6.032902717590332\n",
"epoch: 0, update in batch 97000/???, loss: 4.505112648010254\n",
"epoch: 0, update in batch 98000/???, loss: 6.946290493011475\n",
"epoch: 0, update in batch 99000/???, loss: 6.237973213195801\n",
"epoch: 0, update in batch 100000/???, loss: 6.963421821594238\n",
"epoch: 0, update in batch 101000/???, loss: 5.309017181396484\n",
"epoch: 0, update in batch 102000/???, loss: 6.242384910583496\n",
"epoch: 0, update in batch 103000/???, loss: 6.8203558921813965\n",
"epoch: 0, update in batch 104000/???, loss: 6.242025852203369\n",
"epoch: 0, update in batch 105000/???, loss: 6.765100002288818\n",
"epoch: 0, update in batch 106000/???, loss: 6.8838043212890625\n",
"epoch: 0, update in batch 107000/???, loss: 6.856662750244141\n",
"epoch: 0, update in batch 108000/???, loss: 6.379549503326416\n",
"epoch: 0, update in batch 109000/???, loss: 6.797707557678223\n",
"epoch: 0, update in batch 110000/???, loss: 7.2699689865112305\n",
"epoch: 0, update in batch 111000/???, loss: 7.040057182312012\n",
"epoch: 0, update in batch 112000/???, loss: 6.7861223220825195\n",
"epoch: 0, update in batch 113000/???, loss: 6.064489364624023\n",
"epoch: 0, update in batch 114000/???, loss: 6.095967769622803\n",
"epoch: 0, update in batch 115000/???, loss: 5.757347106933594\n",
"epoch: 0, update in batch 116000/???, loss: 6.529908657073975\n",
"epoch: 0, update in batch 117000/???, loss: 6.030801296234131\n",
"epoch: 0, update in batch 118000/???, loss: 6.179767608642578\n",
"epoch: 0, update in batch 119000/???, loss: 5.436234474182129\n",
"epoch: 0, update in batch 120000/???, loss: 7.342876434326172\n",
"epoch: 0, update in batch 121000/???, loss: 6.862719535827637\n",
"epoch: 0, update in batch 122000/???, loss: 6.491606712341309\n",
"epoch: 0, update in batch 123000/???, loss: 7.195406436920166\n",
"epoch: 0, update in batch 124000/???, loss: 5.481313228607178\n",
"epoch: 0, update in batch 125000/???, loss: 7.963885307312012\n",
"epoch: 0, update in batch 126000/???, loss: 6.479039669036865\n",
"epoch: 0, update in batch 127000/???, loss: 7.037934303283691\n",
"epoch: 0, update in batch 128000/???, loss: 5.903053283691406\n",
"epoch: 0, update in batch 129000/???, loss: 6.815878391265869\n",
"epoch: 0, update in batch 130000/???, loss: 6.497969150543213\n",
"epoch: 0, update in batch 131000/???, loss: 5.623625755310059\n",
"epoch: 0, update in batch 132000/???, loss: 7.118441104888916\n",
"epoch: 0, update in batch 133000/???, loss: 5.964345455169678\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, update in batch 134000/???, loss: 6.112139701843262\n",
"epoch: 0, update in batch 135000/???, loss: 6.5865373611450195\n",
"epoch: 0, update in batch 136000/???, loss: 7.498536109924316\n",
"epoch: 0, update in batch 137000/???, loss: 7.124758243560791\n",
"epoch: 0, update in batch 138000/???, loss: 6.871796607971191\n",
"epoch: 0, update in batch 139000/???, loss: 5.8565263748168945\n",
"epoch: 0, update in batch 140000/???, loss: 5.723143577575684\n",
"epoch: 0, update in batch 141000/???, loss: 5.601426124572754\n",
"epoch: 0, update in batch 142000/???, loss: 5.495566368103027\n",
"epoch: 0, update in batch 143000/???, loss: 6.936192989349365\n",
"epoch: 0, update in batch 144000/???, loss: 6.1843671798706055\n",
"epoch: 0, update in batch 145000/???, loss: 6.886034965515137\n",
"epoch: 0, update in batch 146000/???, loss: 6.655320167541504\n",
"epoch: 0, update in batch 147000/???, loss: 6.46828556060791\n",
"epoch: 0, update in batch 148000/???, loss: 5.607057571411133\n",
"epoch: 0, update in batch 149000/???, loss: 7.182212829589844\n",
"epoch: 0, update in batch 150000/???, loss: 7.241323947906494\n",
"epoch: 0, update in batch 151000/???, loss: 7.308540344238281\n",
"epoch: 0, update in batch 152000/???, loss: 5.267911434173584\n",
"epoch: 0, update in batch 153000/???, loss: 5.895949363708496\n",
"epoch: 0, update in batch 154000/???, loss: 6.629178524017334\n",
"epoch: 0, update in batch 155000/???, loss: 4.9156012535095215\n",
"epoch: 0, update in batch 156000/???, loss: 7.181819915771484\n",
"epoch: 0, update in batch 157000/???, loss: 7.438391208648682\n",
"epoch: 0, update in batch 158000/???, loss: 6.406006813049316\n",
"epoch: 0, update in batch 159000/???, loss: 6.486207008361816\n",
"epoch: 0, update in batch 160000/???, loss: 7.041951656341553\n",
"epoch: 0, update in batch 161000/???, loss: 5.310082912445068\n",
"epoch: 0, update in batch 162000/???, loss: 6.9074387550354\n",
"epoch: 0, update in batch 163000/???, loss: 6.644919395446777\n",
"epoch: 0, update in batch 164000/???, loss: 6.011733055114746\n",
"epoch: 0, update in batch 165000/???, loss: 6.494180202484131\n",
"epoch: 0, update in batch 166000/???, loss: 5.390150547027588\n",
"epoch: 0, update in batch 167000/???, loss: 6.627297401428223\n",
"epoch: 0, update in batch 168000/???, loss: 6.9020209312438965\n",
"epoch: 0, update in batch 169000/???, loss: 7.317750453948975\n",
"epoch: 0, update in batch 170000/???, loss: 5.69993782043457\n",
"epoch: 0, update in batch 171000/???, loss: 6.658817291259766\n",
"epoch: 0, update in batch 172000/???, loss: 6.422945976257324\n",
"epoch: 0, update in batch 173000/???, loss: 5.822269439697266\n",
"epoch: 0, update in batch 174000/???, loss: 6.513391017913818\n",
"epoch: 0, update in batch 175000/???, loss: 5.886590957641602\n",
"epoch: 0, update in batch 176000/???, loss: 7.119387149810791\n",
"epoch: 0, update in batch 177000/???, loss: 6.933981418609619\n",
"epoch: 0, update in batch 178000/???, loss: 6.678143501281738\n",
"epoch: 0, update in batch 179000/???, loss: 6.890423774719238\n",
"epoch: 0, update in batch 180000/???, loss: 6.932961940765381\n",
"epoch: 0, update in batch 181000/???, loss: 6.650975704193115\n",
"epoch: 0, update in batch 182000/???, loss: 6.732748985290527\n",
"epoch: 0, update in batch 183000/???, loss: 6.064764976501465\n",
"epoch: 0, update in batch 184000/???, loss: 5.282295227050781\n",
"epoch: 0, update in batch 185000/???, loss: 6.569302558898926\n",
"epoch: 0, update in batch 186000/???, loss: 5.800485610961914\n",
"epoch: 0, update in batch 187000/???, loss: 6.175991058349609\n",
"epoch: 0, update in batch 188000/???, loss: 5.405575752258301\n",
"epoch: 0, update in batch 189000/???, loss: 6.191354751586914\n",
"epoch: 0, update in batch 190000/???, loss: 6.156663417816162\n",
"epoch: 0, update in batch 191000/???, loss: 6.937534332275391\n",
"epoch: 0, update in batch 192000/???, loss: 6.562686920166016\n",
"epoch: 0, update in batch 193000/???, loss: 6.639985084533691\n",
"epoch: 0, update in batch 194000/???, loss: 7.285438537597656\n",
"epoch: 0, update in batch 195000/???, loss: 6.528258323669434\n",
"epoch: 0, update in batch 196000/???, loss: 8.326434135437012\n",
"epoch: 0, update in batch 197000/???, loss: 6.781360626220703\n",
"epoch: 0, update in batch 198000/???, loss: 7.223299980163574\n",
"epoch: 0, update in batch 199000/???, loss: 6.411007881164551\n",
"epoch: 0, update in batch 200000/???, loss: 5.885635852813721\n",
"epoch: 0, update in batch 201000/???, loss: 5.706809043884277\n",
"epoch: 0, update in batch 202000/???, loss: 6.230217933654785\n",
"epoch: 0, update in batch 203000/???, loss: 7.056562900543213\n",
"epoch: 0, update in batch 204000/???, loss: 7.2273077964782715\n",
"epoch: 0, update in batch 205000/???, loss: 6.342462539672852\n",
"epoch: 0, update in batch 206000/???, loss: 6.556817054748535\n",
"epoch: 0, update in batch 207000/???, loss: 5.882349967956543\n",
"epoch: 0, update in batch 208000/???, loss: 6.755805015563965\n",
"epoch: 0, update in batch 209000/???, loss: 6.5045623779296875\n",
"epoch: 0, update in batch 210000/???, loss: 6.525590419769287\n",
"epoch: 0, update in batch 211000/???, loss: 6.49679708480835\n",
"epoch: 0, update in batch 212000/???, loss: 6.562323093414307\n",
"epoch: 0, update in batch 213000/???, loss: 5.227139472961426\n",
"epoch: 0, update in batch 214000/???, loss: 7.044825077056885\n",
"epoch: 0, update in batch 215000/???, loss: 6.002442359924316\n",
"epoch: 0, update in batch 216000/???, loss: 6.084803581237793\n",
"epoch: 0, update in batch 217000/???, loss: 7.425839900970459\n",
"epoch: 0, update in batch 218000/???, loss: 6.818853855133057\n",
"epoch: 0, update in batch 219000/???, loss: 7.0153374671936035\n",
"epoch: 0, update in batch 220000/???, loss: 6.219962120056152\n",
"epoch: 0, update in batch 221000/???, loss: 5.9975385665893555\n",
"epoch: 0, update in batch 222000/???, loss: 6.480047702789307\n",
"epoch: 0, update in batch 223000/???, loss: 6.405727386474609\n",
"epoch: 0, update in batch 224000/???, loss: 4.7763471603393555\n",
"epoch: 0, update in batch 225000/???, loss: 6.615710258483887\n",
"epoch: 0, update in batch 226000/???, loss: 6.385044574737549\n",
"epoch: 0, update in batch 227000/???, loss: 7.260453701019287\n",
"epoch: 0, update in batch 228000/???, loss: 6.9794135093688965\n",
"epoch: 0, update in batch 229000/???, loss: 6.235735893249512\n",
"epoch: 0, update in batch 230000/???, loss: 6.478426456451416\n",
"epoch: 0, update in batch 231000/???, loss: 6.181302547454834\n",
"epoch: 0, update in batch 232000/???, loss: 5.826043128967285\n",
"epoch: 0, update in batch 233000/???, loss: 5.9517011642456055\n",
"epoch: 0, update in batch 234000/???, loss: 8.0064697265625\n",
"epoch: 0, update in batch 235000/???, loss: 6.7822675704956055\n",
"epoch: 0, update in batch 236000/???, loss: 6.293349742889404\n",
"epoch: 0, update in batch 237000/???, loss: 6.442999362945557\n",
"epoch: 0, update in batch 238000/???, loss: 6.282561302185059\n",
"epoch: 0, update in batch 239000/???, loss: 7.166723728179932\n",
"epoch: 0, update in batch 240000/???, loss: 7.189905643463135\n",
"epoch: 0, update in batch 241000/???, loss: 8.462562561035156\n",
"epoch: 0, update in batch 242000/???, loss: 7.446291923522949\n",
"epoch: 0, update in batch 243000/???, loss: 6.382981777191162\n",
"epoch: 0, update in batch 244000/???, loss: 7.635994911193848\n",
"epoch: 0, update in batch 245000/???, loss: 6.635537147521973\n",
"epoch: 0, update in batch 246000/???, loss: 6.068560600280762\n",
"epoch: 0, update in batch 247000/???, loss: 6.193384170532227\n",
"epoch: 0, update in batch 248000/???, loss: 5.702363967895508\n",
"epoch: 0, update in batch 249000/???, loss: 6.09995174407959\n",
"epoch: 0, update in batch 250000/???, loss: 6.312221050262451\n",
"epoch: 0, update in batch 251000/???, loss: 5.853858470916748\n",
"epoch: 0, update in batch 252000/???, loss: 5.886989593505859\n",
"epoch: 0, update in batch 253000/???, loss: 5.801788330078125\n",
"epoch: 0, update in batch 254000/???, loss: 6.032407760620117\n",
"epoch: 0, update in batch 255000/???, loss: 7.480917453765869\n",
"epoch: 0, update in batch 256000/???, loss: 6.578718662261963\n",
"epoch: 0, update in batch 257000/???, loss: 6.344462871551514\n",
"epoch: 0, update in batch 258000/???, loss: 5.939858436584473\n",
"epoch: 0, update in batch 259000/???, loss: 5.181772232055664\n",
"epoch: 0, update in batch 260000/???, loss: 6.640598297119141\n",
"epoch: 0, update in batch 261000/???, loss: 7.189258575439453\n",
"epoch: 0, update in batch 262000/???, loss: 6.2269287109375\n",
"epoch: 0, update in batch 263000/???, loss: 5.8858795166015625\n",
"epoch: 0, update in batch 264000/???, loss: 6.333988666534424\n",
"epoch: 0, update in batch 265000/???, loss: 6.313681602478027\n",
"epoch: 0, update in batch 266000/???, loss: 5.485809803009033\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, update in batch 267000/???, loss: 6.250800609588623\n",
"epoch: 0, update in batch 268000/???, loss: 6.676806449890137\n",
"epoch: 0, update in batch 269000/???, loss: 5.6487932205200195\n",
"epoch: 0, update in batch 270000/???, loss: 6.648938179016113\n",
"epoch: 0, update in batch 271000/???, loss: 6.26931095123291\n",
"epoch: 0, update in batch 272000/???, loss: 5.343636512756348\n",
"epoch: 0, update in batch 273000/???, loss: 7.051453590393066\n",
"epoch: 0, update in batch 274000/???, loss: 4.578436851501465\n",
"epoch: 0, update in batch 275000/???, loss: 5.400996685028076\n",
"epoch: 0, update in batch 276000/???, loss: 6.129047870635986\n",
"epoch: 0, update in batch 277000/???, loss: 7.549851894378662\n",
"epoch: 0, update in batch 278000/???, loss: 6.093559265136719\n",
"epoch: 0, update in batch 279000/???, loss: 5.6921467781066895\n",
"epoch: 0, update in batch 280000/???, loss: 5.789463996887207\n",
"epoch: 0, update in batch 281000/???, loss: 5.681942939758301\n",
"epoch: 0, update in batch 282000/???, loss: 6.750497341156006\n",
"epoch: 0, update in batch 283000/???, loss: 5.960292339324951\n",
"epoch: 0, update in batch 284000/???, loss: 6.160388469696045\n",
"epoch: 0, update in batch 285000/???, loss: 7.137685298919678\n",
"epoch: 0, update in batch 286000/???, loss: 7.7431464195251465\n",
"epoch: 0, update in batch 287000/???, loss: 5.229738712310791\n",
"epoch: 0, update in batch 288000/???, loss: 6.654232025146484\n",
"epoch: 0, update in batch 289000/???, loss: 6.229329586029053\n",
"epoch: 0, update in batch 290000/???, loss: 7.188180446624756\n",
"epoch: 0, update in batch 291000/???, loss: 6.244111061096191\n",
"epoch: 0, update in batch 292000/???, loss: 7.199154853820801\n",
"epoch: 0, update in batch 293000/???, loss: 7.1866865158081055\n",
"epoch: 0, update in batch 294000/???, loss: 6.574115753173828\n",
"epoch: 0, update in batch 295000/???, loss: 6.487138271331787\n",
"epoch: 0, update in batch 296000/???, loss: 5.813161849975586\n",
"epoch: 0, update in batch 297000/???, loss: 6.159414291381836\n",
"epoch: 0, update in batch 298000/???, loss: 7.256616115570068\n",
"epoch: 0, update in batch 299000/???, loss: 7.511231899261475\n",
"epoch: 0, update in batch 300000/???, loss: 6.148821830749512\n",
"epoch: 0, update in batch 301000/???, loss: 7.108969211578369\n",
"epoch: 0, update in batch 302000/???, loss: 6.528176307678223\n",
"epoch: 0, update in batch 303000/???, loss: 6.276839256286621\n",
"epoch: 0, update in batch 304000/???, loss: 6.484020233154297\n",
"epoch: 0, update in batch 305000/???, loss: 6.38557767868042\n",
"epoch: 0, update in batch 306000/???, loss: 7.068814754486084\n",
"epoch: 0, update in batch 307000/???, loss: 5.844017505645752\n",
"epoch: 0, update in batch 308000/???, loss: 4.25785493850708\n",
"epoch: 0, update in batch 309000/???, loss: 6.709985256195068\n",
"epoch: 0, update in batch 310000/???, loss: 6.543104648590088\n",
"epoch: 0, update in batch 311000/???, loss: 6.675828456878662\n",
"epoch: 0, update in batch 312000/???, loss: 5.82969856262207\n",
"epoch: 0, update in batch 313000/???, loss: 6.05246639251709\n",
"epoch: 0, update in batch 314000/???, loss: 7.2366156578063965\n",
"epoch: 0, update in batch 315000/???, loss: 5.039820194244385\n",
"epoch: 0, update in batch 316000/???, loss: 5.943173885345459\n",
"epoch: 0, update in batch 317000/???, loss: 6.2509002685546875\n",
"epoch: 0, update in batch 318000/???, loss: 6.451228141784668\n",
"epoch: 0, update in batch 319000/???, loss: 6.6381049156188965\n",
"epoch: 0, update in batch 320000/???, loss: 6.570329189300537\n",
"epoch: 0, update in batch 321000/???, loss: 5.376622200012207\n",
"epoch: 0, update in batch 322000/???, loss: 6.487462520599365\n",
"epoch: 0, update in batch 323000/???, loss: 6.676497459411621\n",
"epoch: 0, update in batch 324000/???, loss: 6.283420562744141\n",
"epoch: 0, update in batch 325000/???, loss: 6.164648532867432\n",
"epoch: 0, update in batch 326000/???, loss: 6.839153289794922\n",
"epoch: 0, update in batch 327000/???, loss: 6.435141086578369\n",
"epoch: 0, update in batch 328000/???, loss: 6.160590171813965\n",
"epoch: 0, update in batch 329000/???, loss: 5.876160621643066\n",
"epoch: 0, update in batch 330000/???, loss: 6.47445011138916\n",
"epoch: 0, update in batch 331000/???, loss: 6.294231414794922\n",
"epoch: 0, update in batch 332000/???, loss: 6.099027156829834\n",
"epoch: 0, update in batch 333000/???, loss: 6.986542701721191\n",
"epoch: 0, update in batch 334000/???, loss: 7.018263816833496\n",
"epoch: 0, update in batch 335000/???, loss: 6.906959533691406\n",
"epoch: 0, update in batch 336000/???, loss: 6.12356424331665\n",
"epoch: 0, update in batch 337000/???, loss: 6.316069602966309\n",
"epoch: 0, update in batch 338000/???, loss: 6.908566474914551\n",
"epoch: 0, update in batch 339000/???, loss: 5.628839492797852\n",
"epoch: 0, update in batch 340000/???, loss: 7.069979667663574\n",
"epoch: 0, update in batch 341000/???, loss: 5.350735187530518\n",
"epoch: 0, update in batch 342000/???, loss: 5.377245903015137\n",
"epoch: 0, update in batch 343000/???, loss: 5.2340989112854\n",
"epoch: 0, update in batch 344000/???, loss: 6.087491512298584\n",
"epoch: 0, update in batch 345000/???, loss: 6.162985801696777\n",
"epoch: 0, update in batch 346000/???, loss: 5.655491828918457\n",
"epoch: 0, update in batch 347000/???, loss: 5.311842918395996\n",
"epoch: 0, update in batch 348000/???, loss: 7.577170372009277\n",
"epoch: 0, update in batch 349000/???, loss: 6.730460166931152\n",
"epoch: 0, update in batch 350000/???, loss: 6.782231330871582\n",
"epoch: 0, update in batch 351000/???, loss: 6.789486885070801\n",
"epoch: 0, update in batch 352000/???, loss: 5.473587989807129\n",
"epoch: 0, update in batch 353000/???, loss: 5.531443119049072\n",
"epoch: 0, update in batch 354000/???, loss: 7.220989227294922\n",
"epoch: 0, update in batch 355000/???, loss: 5.954288005828857\n",
"epoch: 0, update in batch 356000/???, loss: 4.112783432006836\n",
"epoch: 0, update in batch 357000/???, loss: 5.409672737121582\n",
"epoch: 0, update in batch 358000/???, loss: 6.408724784851074\n",
"epoch: 0, update in batch 359000/???, loss: 6.744941711425781\n",
"epoch: 0, update in batch 360000/???, loss: 6.218225479125977\n",
"epoch: 0, update in batch 361000/???, loss: 6.071394920349121\n",
"epoch: 0, update in batch 362000/???, loss: 6.137121677398682\n",
"epoch: 0, update in batch 363000/???, loss: 5.876864433288574\n",
"epoch: 0, update in batch 364000/???, loss: 7.715426445007324\n",
"epoch: 0, update in batch 365000/???, loss: 6.217362880706787\n",
"epoch: 0, update in batch 366000/???, loss: 6.741396903991699\n",
"epoch: 0, update in batch 367000/???, loss: 6.4564313888549805\n",
"epoch: 0, update in batch 368000/???, loss: 6.994439601898193\n",
"epoch: 0, update in batch 369000/???, loss: 6.061278820037842\n",
"epoch: 0, update in batch 370000/???, loss: 4.894576549530029\n",
"epoch: 0, update in batch 371000/???, loss: 6.351264953613281\n",
"epoch: 0, update in batch 372000/???, loss: 6.826904296875\n",
"epoch: 0, update in batch 373000/???, loss: 6.090312480926514\n",
"epoch: 0, update in batch 374000/???, loss: 5.797528266906738\n",
"epoch: 0, update in batch 375000/???, loss: 7.3235673904418945\n",
"epoch: 0, update in batch 376000/???, loss: 5.5752973556518555\n",
"epoch: 0, update in batch 377000/???, loss: 6.29438591003418\n",
"epoch: 0, update in batch 378000/???, loss: 5.238917827606201\n",
"epoch: 0, update in batch 379000/???, loss: 5.542972564697266\n",
"epoch: 0, update in batch 380000/???, loss: 6.5024614334106445\n",
"epoch: 0, update in batch 381000/???, loss: 6.918997287750244\n",
"epoch: 0, update in batch 382000/???, loss: 5.694029331207275\n",
"epoch: 0, update in batch 383000/???, loss: 7.109190940856934\n",
"epoch: 0, update in batch 384000/???, loss: 5.214654445648193\n",
"epoch: 0, update in batch 385000/???, loss: 7.055975437164307\n",
"epoch: 0, update in batch 386000/???, loss: 6.443846225738525\n",
"epoch: 0, update in batch 387000/???, loss: 5.544674873352051\n",
"epoch: 0, update in batch 388000/???, loss: 6.936171531677246\n",
"epoch: 0, update in batch 389000/???, loss: 6.646860599517822\n",
"epoch: 0, update in batch 390000/???, loss: 6.193584442138672\n",
"epoch: 0, update in batch 391000/???, loss: 5.9077558517456055\n",
"epoch: 0, update in batch 392000/???, loss: 5.029908657073975\n",
"epoch: 0, update in batch 393000/???, loss: 6.725222587585449\n",
"epoch: 0, update in batch 394000/???, loss: 6.814855098724365\n",
"epoch: 0, update in batch 395000/???, loss: 7.396543979644775\n",
"epoch: 0, update in batch 396000/???, loss: 6.993375301361084\n",
"epoch: 0, update in batch 397000/???, loss: 6.224326133728027\n",
"epoch: 0, update in batch 398000/???, loss: 6.301025390625\n",
"epoch: 0, update in batch 399000/???, loss: 6.707190036773682\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, update in batch 400000/???, loss: 6.7646660804748535\n",
"epoch: 0, update in batch 401000/???, loss: 5.308177471160889\n",
"epoch: 0, update in batch 402000/???, loss: 8.35996150970459\n",
"epoch: 0, update in batch 403000/???, loss: 5.825610160827637\n",
"epoch: 0, update in batch 404000/???, loss: 6.310220718383789\n",
"epoch: 0, update in batch 405000/???, loss: 5.759210109710693\n",
"epoch: 0, update in batch 406000/???, loss: 6.32699728012085\n",
"epoch: 0, update in batch 407000/???, loss: 5.659378528594971\n",
"epoch: 0, update in batch 408000/???, loss: 6.216103553771973\n",
"epoch: 0, update in batch 409000/???, loss: 5.666914463043213\n",
"epoch: 0, update in batch 410000/???, loss: 6.419122219085693\n",
"epoch: 0, update in batch 411000/???, loss: 5.372750282287598\n",
"epoch: 0, update in batch 412000/???, loss: 6.839580535888672\n",
"epoch: 0, update in batch 413000/???, loss: 6.7682647705078125\n",
"epoch: 0, update in batch 414000/???, loss: 5.951648235321045\n",
"epoch: 0, update in batch 415000/???, loss: 6.181953430175781\n",
"epoch: 0, update in batch 416000/???, loss: 5.475704669952393\n",
"epoch: 0, update in batch 417000/???, loss: 6.383082866668701\n",
"epoch: 0, update in batch 418000/???, loss: 6.8107590675354\n",
"epoch: 0, update in batch 419000/???, loss: 5.753104209899902\n",
"epoch: 0, update in batch 420000/???, loss: 5.320840835571289\n",
"epoch: 0, update in batch 421000/???, loss: 7.377203464508057\n",
"epoch: 0, update in batch 422000/???, loss: 6.5706048011779785\n",
"epoch: 0, update in batch 423000/???, loss: 5.032872676849365\n",
"epoch: 0, update in batch 424000/???, loss: 5.781243324279785\n",
"epoch: 0, update in batch 425000/???, loss: 6.160118579864502\n"
]
}
],
"source": [
"train(train_dataset, model, 1, 64)"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "36a1b802",
"metadata": {},
"outputs": [],
"source": [
"def predict_probs(tokens):\n",
" model.eval()\n",
" state_h = model.init_state(len(tokens))\n",
"\n",
" x = torch.tensor([[train_dataset.key_to_index[w] if w in key_to_index else train_dataset.key_to_index['<unk>'] for w in tokens]]).to(device)\n",
" y_pred, state_h = model(x)\n",
"\n",
" last_word_logits = y_pred[0][-1]\n",
" probs = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()\n",
" word_index = np.random.choice(len(last_word_logits), p=probs)\n",
" \n",
" top_words = []\n",
" for index in range(len(probs)):\n",
" if len(top_words) < 30:\n",
" top_words.append((probs[index], [index]))\n",
" else:\n",
" worst_word = None\n",
" for word in top_words:\n",
" if not worst_word:\n",
" worst_word = word\n",
" else:\n",
" if word[0] < worst_word[0]:\n",
" worst_word = word\n",
" if worst_word[0] < probs[index] and index != len(probs) - 1:\n",
" top_words.remove(worst_word)\n",
" top_words.append((probs[index], [index]))\n",
" \n",
" prediction = ''\n",
" sum_prob = 0.0\n",
" for word in top_words:\n",
" sum_prob += word[0]\n",
" word_index = word[0]\n",
" word_text = index_to_key[word[1][0]]\n",
" prediction += f'{word_text}:{word_index} '\n",
" prediction += f':{1 - sum_prob}'\n",
" \n",
" return prediction"
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "155636b5",
"metadata": {},
"outputs": [],
"source": [
"dev_data = pd.read_csv('dev-0/in.tsv.xz', sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)\n",
"test_data = pd.read_csv('test-A/in.tsv.xz', sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)"
]
},
{
"cell_type": "code",
"execution_count": 59,
"id": "99b1d944",
"metadata": {},
"outputs": [],
"source": [
"with open('dev-0/out.tsv', 'w') as file:\n",
" for index, row in dev_data.iterrows():\n",
" left_text = clean_text(str(row[6]))\n",
" left_words = word_tokenize(left_text)\n",
" if len(left_words) < 6:\n",
" prediction = ':1.0'\n",
" else:\n",
" prediction = predict_probs(left_words[-5:])\n",
" file.write(prediction + '\\n')"
]
},
{
"cell_type": "code",
"execution_count": 60,
"id": "186c3269",
"metadata": {},
"outputs": [],
"source": [
"with open('test-A/out.tsv', 'w') as file:\n",
" for index, row in test_data.iterrows():\n",
" left_text = clean_text(str(row[6]))\n",
" left_words = word_tokenize(left_text)\n",
" if len(left_words) < 6:\n",
" prediction = ':1.0'\n",
" else:\n",
" prediction = predict_probs(left_words[-5:])\n",
" file.write(prediction + '\\n')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}