challenging-america-word-ga.../run.py
2022-06-06 11:25:46 +02:00

718 lines
36 KiB
Python

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "46fe9a72-d787-46ab-a9df-c6732c173a26",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import regex as re\n",
"import csv\n",
"import torch\n",
"from torch import nn\n",
"from gensim.models import Word2Vec\n",
"from nltk.tokenize import word_tokenize"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "6fd72918-57e4-44f2-a0e1-a71f322df5f7",
"metadata": {},
"outputs": [],
"source": [
"def clean_text(text):\n",
"    \"\"\"Normalise a raw text fragment before tokenisation.\n",
"\n",
"    Steps, in order: lower-case; drop a hyphen followed by an escaped line\n",
"    break (re-joining a word split across lines); turn every remaining\n",
"    escaped line break into a space; expand a few English contractions;\n",
"    strip all Unicode punctuation characters.\n",
"    \"\"\"\n",
"    text = text.lower()\n",
"    # An escaped line break is a run of literal backslashes followed by 'n'.\n",
"    # Matching one-or-more backslashes (instead of a fixed-width literal)\n",
"    # keeps the rule working however many escaping layers the text went through.\n",
"    text = re.sub(r'-\\\\+n', '', text)\n",
"    text = re.sub(r'\\\\+n', ' ', text)\n",
"    # Expand contractions BEFORE stripping punctuation: once the apostrophes\n",
"    # are removed these replacements can never match.\n",
"    text = text.replace(\"'t\", \" not\").replace(\"'s\", \" is\").replace(\"'ll\", \" will\").replace(\"'m\", \" am\").replace(\"'ve\", \" have\")\n",
"    text = re.sub(r'\\p{P}', '', text)\n",
"\n",
"    return text"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "87dc8a73-f089-4551-b0f0-dedefd0b5a05",
"metadata": {},
"outputs": [],
"source": [
"# Training inputs (left / right context in columns 6 and 7) and the gold gap words.\n",
"# NOTE(review): on_bad_lines='skip' drops rows from each file independently,\n",
"# which would misalign inputs and labels in the concat below - confirm both\n",
"# frames have the same length.\n",
"train_data, train_labels = (\n",
"    pd.read_csv(path, sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)\n",
"    for path in ('train/in.tsv.xz', 'train/expected.tsv')\n",
")\n",
"\n",
"train_data = pd.concat([train_data[[6, 7]], train_labels], axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "1d9244d4-1dd7-4c02-8e22-6ea5feee9b26",
"metadata": {},
"outputs": [],
"source": [
"class TrainCorpus:\n",
"    \"\"\"Re-iterable stream of tokenised training texts (e.g. for gensim Word2Vec).\n",
"\n",
"    Each row of ``data`` is rebuilt as left context (column 6), expected gap\n",
"    word (column 0) and right context (column 7), cleaned with ``clean_text``\n",
"    and split with NLTK's ``word_tokenize``. A fresh generator is created on\n",
"    every ``iter()`` call, so the corpus can be scanned multiple times.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, data):\n",
"        self.data = data\n",
"\n",
"    def __iter__(self):\n",
"        for _, row in self.data.iterrows():\n",
"            # Separate the three parts with spaces: plain concatenation glued the\n",
"            # gap word onto its neighbouring words and produced bogus tokens.\n",
"            text = ' '.join((str(row[6]), str(row[0]), str(row[7])))\n",
"            yield word_tokenize(clean_text(text))"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "5b6f017b-990f-42e2-8ef5-ce40824ace61",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"81477\n"
]
}
],
"source": [
"train_sentences = TrainCorpus(train_data.head(80000))\n",
"w2v_model = Word2Vec(vector_size=100, min_count=10)\n",
"w2v_model.build_vocab(corpus_iterable=train_sentences)\n",
"\n",
"# Work on copies of gensim's vocabulary: appending to the model's own\n",
"# wv.index_to_key / wv.key_to_index would leave them out of sync with\n",
"# the KeyedVectors' vector array.\n",
"index_to_key = list(w2v_model.wv.index_to_key)\n",
"key_to_index = dict(w2v_model.wv.key_to_index)\n",
"\n",
"# Out-of-vocabulary words share a dedicated '<unk>' id, the last index.\n",
"index_to_key.append('<unk>')\n",
"key_to_index['<unk>'] = len(index_to_key) - 1\n",
"\n",
"vocab_size = len(index_to_key)\n",
"print(vocab_size)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "516767c7-ce51-4483-b99f-675cfd4fe99d",
"metadata": {},
"outputs": [],
"source": [
"class TrainDataset(torch.utils.data.IterableDataset):\n",
"    \"\"\"Stream of (input ids, target ids) windows for next-word prediction.\n",
"\n",
"    Every row of ``data`` is rebuilt as left context (column 6) + expected gap\n",
"    word (column 0) + right context (column 7), cleaned and tokenised. With\n",
"    ``reversed=True`` the token order is flipped, so a model trained on it\n",
"    predicts the preceding word instead of the following one.\n",
"\n",
"    For each position ``i`` it yields ``tokens[i-context_size:i]`` as input and\n",
"    the same window shifted one token to the right as target, both as int64\n",
"    id arrays. Words missing from ``key_to_index`` map to the ``'<unk>'`` id.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, data, index_to_key, key_to_index, reversed=False, context_size=5):\n",
"        self.reversed = reversed\n",
"        self.data = data\n",
"        self.index_to_key = index_to_key\n",
"        self.key_to_index = key_to_index\n",
"        self.vocab_size = len(key_to_index)\n",
"        self.context_size = context_size\n",
"\n",
"    def _to_ids(self, words):\n",
"        # Vocabulary ids as an int64 array, unknown words -> '<unk>'.\n",
"        unk_id = self.key_to_index['<unk>']\n",
"        return np.asarray([self.key_to_index.get(word, unk_id) for word in words], dtype=np.int64)\n",
"\n",
"    def __iter__(self):\n",
"        n = self.context_size\n",
"        for _, row in self.data.iterrows():\n",
"            # Separate the parts with spaces so the gap word is not glued onto\n",
"            # its neighbouring context words.\n",
"            text = ' '.join((str(row[6]), str(row[0]), str(row[7])))\n",
"            tokens = word_tokenize(clean_text(text))\n",
"            if self.reversed:\n",
"                tokens = tokens[::-1]\n",
"            for i in range(n, len(tokens)):\n",
"                yield self._to_ids(tokens[i - n:i]), self._to_ids(tokens[i - n + 1:i + 1])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "678d0388-a2b2-44ca-b686-dc133a0d16e5",
"metadata": {},
"outputs": [],
"source": [
"class Model(nn.Module):\n",
"    \"\"\"Word-level LSTM language model: embedding -> 2-layer LSTM -> linear.\n",
"\n",
"    ``forward`` takes token ids shaped (batch, seq) and returns unnormalised\n",
"    logits shaped (batch, seq, vocab_size) plus the final LSTM ``(h, c)`` state.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, embed_size, vocab_size):\n",
"        super(Model, self).__init__()\n",
"        self.embed_size = embed_size\n",
"        self.vocab_size = vocab_size\n",
"        self.lstm_size = 128\n",
"        self.num_layers = 2\n",
"\n",
"        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=self.embed_size)\n",
"        # batch_first=True because batches arrive as (batch, seq). Without it the\n",
"        # LSTM treats the batch axis as time: during training it recurs across\n",
"        # unrelated samples, and at inference a (1, seq) input is read as seq\n",
"        # independent one-token sequences.\n",
"        self.lstm = nn.LSTM(input_size=self.embed_size, hidden_size=self.lstm_size,\n",
"                            num_layers=self.num_layers, dropout=0.2, batch_first=True)\n",
"        self.fc = nn.Linear(self.lstm_size, vocab_size)\n",
"\n",
"    def forward(self, x, prev_state=None):\n",
"        embed = self.embed(x)\n",
"        output, state = self.lstm(embed, prev_state)\n",
"        # Raw logits: callers apply nn.CrossEntropyLoss or an explicit softmax.\n",
"        logits = self.fc(output)\n",
"        return logits, state\n",
"\n",
"    def init_state(self, sequence_length):\n",
"        \"\"\"Zero ``(h, c)`` state; ``sequence_length`` is the batch dimension of the state.\"\"\"\n",
"        # Build the state with the real LSTM width on the device the model lives on.\n",
"        device = next(self.parameters()).device\n",
"        zeros = torch.zeros(self.num_layers, sequence_length, self.lstm_size, device=device)\n",
"        return (zeros, zeros)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "487df613-15cd-450e-a8b2-7e7cd879f8f7",
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"from torch.optim import Adam\n",
"\n",
"def train(dataset, model, max_epochs, batch_size, lr=0.001, log_every=100):\n",
"    \"\"\"Train ``model`` on ``dataset`` with Adam and token-level cross-entropy.\n",
"\n",
"    ``dataset`` yields ``(input_ids, target_ids)`` pairs of equal length; the\n",
"    logits at every position are scored against the shifted targets. The loss\n",
"    of the current batch is printed every ``log_every`` batches ('???' because\n",
"    an iterable dataset has no known number of batches).\n",
"    \"\"\"\n",
"    model.train()\n",
"    # Train on whatever device the model already lives on instead of relying\n",
"    # on a global ``device`` defined in another cell.\n",
"    device = next(model.parameters()).device\n",
"\n",
"    dataloader = DataLoader(dataset, batch_size=batch_size)\n",
"    criterion = nn.CrossEntropyLoss()\n",
"    optimizer = Adam(model.parameters(), lr=lr)\n",
"\n",
"    for epoch in range(max_epochs):\n",
"        for batch, (x, y) in enumerate(dataloader):\n",
"            optimizer.zero_grad()\n",
"\n",
"            x = x.to(device)\n",
"            y = y.to(device)\n",
"\n",
"            y_pred, _ = model(x)\n",
"            # Move the vocab axis of the logits to dim 1, where\n",
"            # CrossEntropyLoss expects the class scores.\n",
"            loss = criterion(y_pred.transpose(1, 2), y)\n",
"\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"\n",
"            if batch % log_every == 0:\n",
"                print(f'epoch: {epoch}, update in batch {batch}/???, loss: {loss.item()}')"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "79df35ad-cb19-4867-869c-b651455580ae",
"metadata": {},
"outputs": [],
"source": [
"# Free cached GPU memory left over from earlier runs, then choose the\n",
"# device that models are moved to.\n",
"torch.cuda.empty_cache()\n",
"if torch.cuda.is_available():\n",
"    device = 'cuda'\n",
"else:\n",
"    device = 'cpu'"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "e4654fd0-6c0a-47e3-b623-73c7c04ea194",
"metadata": {},
"outputs": [],
"source": [
"# Left-to-right training stream built from the first 8000 training rows.\n",
"train_dataset_front = TrainDataset(\n",
"    train_data.head(8000),\n",
"    index_to_key,\n",
"    key_to_index,\n",
"    reversed=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "7b5f09f5-88c4-44d4-ab8a-aa62ec8f70d5",
"metadata": {},
"outputs": [],
"source": [
"# Forward language model with 100-dimensional embeddings over the shared vocabulary.\n",
"model_front = Model(embed_size=100, vocab_size=vocab_size)\n",
"model_front = model_front.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "7e886972-4994-4bbd-aca8-2ab685c7b8db",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, update in batch 0/???, loss: 11.315739631652832\n",
"epoch: 0, update in batch 100/???, loss: 8.016324996948242\n",
"epoch: 0, update in batch 200/???, loss: 7.45602560043335\n",
"epoch: 0, update in batch 300/???, loss: 6.306332588195801\n",
"epoch: 0, update in batch 400/???, loss: 8.629552841186523\n",
"epoch: 0, update in batch 500/???, loss: 7.637443542480469\n",
"epoch: 0, update in batch 600/???, loss: 7.67318868637085\n",
"epoch: 0, update in batch 700/???, loss: 7.2209930419921875\n",
"epoch: 0, update in batch 800/???, loss: 7.739532470703125\n",
"epoch: 0, update in batch 900/???, loss: 7.219891548156738\n",
"epoch: 0, update in batch 1000/???, loss: 6.8804473876953125\n",
"epoch: 0, update in batch 1100/???, loss: 7.228173732757568\n",
"epoch: 0, update in batch 1200/???, loss: 6.513087272644043\n",
"epoch: 0, update in batch 1300/???, loss: 7.142991542816162\n",
"epoch: 0, update in batch 1400/???, loss: 7.711663246154785\n",
"epoch: 0, update in batch 1500/???, loss: 6.894327640533447\n",
"epoch: 0, update in batch 1600/???, loss: 7.723884582519531\n",
"epoch: 0, update in batch 1700/???, loss: 8.409640312194824\n",
"epoch: 0, update in batch 1800/???, loss: 6.570927619934082\n",
"epoch: 0, update in batch 1900/???, loss: 6.906421661376953\n",
"epoch: 0, update in batch 2000/???, loss: 7.197023868560791\n",
"epoch: 0, update in batch 2100/???, loss: 6.892503261566162\n",
"epoch: 0, update in batch 2200/???, loss: 7.109471321105957\n",
"epoch: 0, update in batch 2300/???, loss: 8.84702205657959\n",
"epoch: 0, update in batch 2400/???, loss: 7.394454002380371\n",
"epoch: 0, update in batch 2500/???, loss: 7.380859375\n",
"epoch: 0, update in batch 2600/???, loss: 6.635237693786621\n",
"epoch: 0, update in batch 2700/???, loss: 6.869620323181152\n",
"epoch: 0, update in batch 2800/???, loss: 6.656294822692871\n",
"epoch: 0, update in batch 2900/???, loss: 8.090291976928711\n",
"epoch: 0, update in batch 3000/???, loss: 7.012345314025879\n",
"epoch: 0, update in batch 3100/???, loss: 6.7099809646606445\n",
"epoch: 0, update in batch 3200/???, loss: 6.798626899719238\n",
"epoch: 0, update in batch 3300/???, loss: 6.510752201080322\n",
"epoch: 0, update in batch 3400/???, loss: 7.742552757263184\n",
"epoch: 0, update in batch 3500/???, loss: 7.3319292068481445\n",
"epoch: 0, update in batch 3600/???, loss: 8.022462844848633\n",
"epoch: 0, update in batch 3700/???, loss: 5.883602619171143\n",
"epoch: 0, update in batch 3800/???, loss: 6.235389232635498\n",
"epoch: 0, update in batch 3900/???, loss: 7.012289524078369\n",
"epoch: 0, update in batch 4000/???, loss: 7.005420684814453\n",
"epoch: 0, update in batch 4100/???, loss: 6.595402717590332\n",
"epoch: 0, update in batch 4200/???, loss: 6.7428154945373535\n",
"epoch: 0, update in batch 4300/???, loss: 6.358878135681152\n",
"epoch: 0, update in batch 4400/???, loss: 6.6188201904296875\n",
"epoch: 0, update in batch 4500/???, loss: 7.08281946182251\n",
"epoch: 0, update in batch 4600/???, loss: 5.705609321594238\n",
"epoch: 0, update in batch 4700/???, loss: 7.1878180503845215\n",
"epoch: 0, update in batch 4800/???, loss: 7.071160793304443\n",
"epoch: 0, update in batch 4900/???, loss: 6.768280029296875\n",
"epoch: 0, update in batch 5000/???, loss: 6.507267951965332\n",
"epoch: 0, update in batch 5100/???, loss: 6.6431379318237305\n",
"epoch: 0, update in batch 5200/???, loss: 6.719052314758301\n",
"epoch: 0, update in batch 5300/???, loss: 7.172060489654541\n",
"epoch: 0, update in batch 5400/???, loss: 5.98638916015625\n",
"epoch: 0, update in batch 5500/???, loss: 5.674165725708008\n",
"epoch: 0, update in batch 5600/???, loss: 5.612569808959961\n",
"epoch: 0, update in batch 5700/???, loss: 6.307109832763672\n",
"epoch: 0, update in batch 5800/???, loss: 5.382391452789307\n",
"epoch: 0, update in batch 5900/???, loss: 5.712988376617432\n",
"epoch: 0, update in batch 6000/???, loss: 6.371735572814941\n",
"epoch: 0, update in batch 6100/???, loss: 6.417542457580566\n",
"epoch: 0, update in batch 6200/???, loss: 7.14879846572876\n",
"epoch: 0, update in batch 6300/???, loss: 7.0701189041137695\n",
"epoch: 0, update in batch 6400/???, loss: 7.048495292663574\n",
"epoch: 0, update in batch 6500/???, loss: 7.3384833335876465\n",
"epoch: 0, update in batch 6600/???, loss: 6.561330318450928\n",
"epoch: 0, update in batch 6700/???, loss: 6.839573860168457\n",
"epoch: 0, update in batch 6800/???, loss: 6.5179548263549805\n",
"epoch: 0, update in batch 6900/???, loss: 7.246607303619385\n",
"epoch: 0, update in batch 7000/???, loss: 6.5699052810668945\n",
"epoch: 0, update in batch 7100/???, loss: 7.202715873718262\n",
"epoch: 0, update in batch 7200/???, loss: 6.1833648681640625\n",
"epoch: 0, update in batch 7300/???, loss: 5.977782249450684\n",
"epoch: 0, update in batch 7400/???, loss: 6.717446327209473\n",
"epoch: 0, update in batch 7500/???, loss: 6.574376583099365\n",
"epoch: 0, update in batch 7600/???, loss: 5.8418450355529785\n",
"epoch: 0, update in batch 7700/???, loss: 6.282655715942383\n",
"epoch: 0, update in batch 7800/???, loss: 6.065321922302246\n",
"epoch: 0, update in batch 7900/???, loss: 6.415077209472656\n",
"epoch: 0, update in batch 8000/???, loss: 6.482673645019531\n",
"epoch: 0, update in batch 8100/???, loss: 6.670407772064209\n",
"epoch: 0, update in batch 8200/???, loss: 6.799211025238037\n",
"epoch: 0, update in batch 8300/???, loss: 7.299313545227051\n",
"epoch: 0, update in batch 8400/???, loss: 7.42974328994751\n",
"epoch: 0, update in batch 8500/???, loss: 8.549559593200684\n",
"epoch: 0, update in batch 8600/???, loss: 6.794680118560791\n",
"epoch: 0, update in batch 8700/???, loss: 7.390380859375\n",
"epoch: 0, update in batch 8800/???, loss: 7.552660942077637\n",
"epoch: 0, update in batch 8900/???, loss: 6.663547515869141\n",
"epoch: 0, update in batch 9000/???, loss: 6.5236711502075195\n",
"epoch: 0, update in batch 9100/???, loss: 7.666424751281738\n",
"epoch: 0, update in batch 9200/???, loss: 6.479496955871582\n",
"epoch: 0, update in batch 9300/???, loss: 5.5056304931640625\n",
"epoch: 0, update in batch 9400/???, loss: 6.6904096603393555\n",
"epoch: 0, update in batch 9500/???, loss: 6.9318037033081055\n",
"epoch: 0, update in batch 9600/???, loss: 6.521365165710449\n",
"epoch: 0, update in batch 9700/???, loss: 6.376631736755371\n",
"epoch: 0, update in batch 9800/???, loss: 6.4104766845703125\n",
"epoch: 0, update in batch 9900/???, loss: 7.3995232582092285\n",
"epoch: 0, update in batch 10000/???, loss: 6.510337829589844\n",
"epoch: 0, update in batch 10100/???, loss: 6.2512407302856445\n",
"epoch: 0, update in batch 10200/???, loss: 6.048404216766357\n",
"epoch: 0, update in batch 10300/???, loss: 6.832150936126709\n",
"epoch: 0, update in batch 10400/???, loss: 6.7485456466674805\n",
"epoch: 0, update in batch 10500/???, loss: 5.385656833648682\n",
"epoch: 0, update in batch 10600/???, loss: 6.769070625305176\n",
"epoch: 0, update in batch 10700/???, loss: 6.857029914855957\n",
"epoch: 0, update in batch 10800/???, loss: 5.991332530975342\n",
"epoch: 0, update in batch 10900/???, loss: 6.5500006675720215\n",
"epoch: 0, update in batch 11000/???, loss: 6.951509952545166\n",
"epoch: 0, update in batch 11100/???, loss: 6.396986961364746\n",
"epoch: 0, update in batch 11200/???, loss: 6.639346122741699\n",
"epoch: 0, update in batch 11300/???, loss: 5.87351655960083\n",
"epoch: 0, update in batch 11400/???, loss: 5.996974945068359\n",
"epoch: 0, update in batch 11500/???, loss: 7.103158473968506\n",
"epoch: 0, update in batch 11600/???, loss: 6.429941654205322\n",
"epoch: 0, update in batch 11700/???, loss: 5.597273826599121\n",
"epoch: 0, update in batch 11800/???, loss: 7.112508296966553\n",
"epoch: 0, update in batch 11900/???, loss: 6.745194911956787\n",
"epoch: 0, update in batch 12000/???, loss: 7.47100305557251\n",
"epoch: 0, update in batch 12100/???, loss: 6.847914695739746\n",
"epoch: 0, update in batch 12200/???, loss: 6.876992702484131\n",
"epoch: 0, update in batch 12300/???, loss: 6.499053955078125\n",
"epoch: 0, update in batch 12400/???, loss: 7.196413993835449\n",
"epoch: 0, update in batch 12500/???, loss: 6.593430995941162\n",
"epoch: 0, update in batch 12600/???, loss: 6.368945121765137\n",
"epoch: 0, update in batch 12700/???, loss: 6.362246513366699\n",
"epoch: 0, update in batch 12800/???, loss: 7.209506034851074\n",
"epoch: 0, update in batch 12900/???, loss: 6.8092780113220215\n",
"epoch: 0, update in batch 13000/???, loss: 8.273663520812988\n",
"epoch: 0, update in batch 13100/???, loss: 7.061187744140625\n",
"epoch: 0, update in batch 13200/???, loss: 5.778809547424316\n",
"epoch: 0, update in batch 13300/???, loss: 5.650263786315918\n",
"epoch: 0, update in batch 13400/???, loss: 5.9032440185546875\n",
"epoch: 0, update in batch 13500/???, loss: 6.629636287689209\n",
"epoch: 0, update in batch 13600/???, loss: 6.577019691467285\n",
"epoch: 0, update in batch 13700/???, loss: 5.953114032745361\n",
"epoch: 0, update in batch 13800/???, loss: 6.630902290344238\n",
"epoch: 0, update in batch 13900/???, loss: 7.593966484069824\n",
"epoch: 0, update in batch 14000/???, loss: 6.636081695556641\n",
"epoch: 0, update in batch 14100/???, loss: 5.772985458374023\n",
"epoch: 0, update in batch 14200/???, loss: 5.907249450683594\n",
"epoch: 0, update in batch 14300/???, loss: 7.863391876220703\n",
"epoch: 0, update in batch 14400/???, loss: 7.275572776794434\n",
"epoch: 0, update in batch 14500/???, loss: 6.818984031677246\n",
"epoch: 0, update in batch 14600/???, loss: 6.0456342697143555\n",
"epoch: 0, update in batch 14700/???, loss: 6.281990051269531\n",
"epoch: 0, update in batch 14800/???, loss: 6.197850227355957\n",
"epoch: 0, update in batch 14900/???, loss: 5.851240634918213\n",
"epoch: 0, update in batch 15000/???, loss: 6.826748847961426\n",
"epoch: 0, update in batch 15100/???, loss: 7.2189483642578125\n",
"epoch: 0, update in batch 15200/???, loss: 6.609204292297363\n",
"epoch: 0, update in batch 15300/???, loss: 6.947709560394287\n",
"epoch: 0, update in batch 15400/???, loss: 6.604478359222412\n",
"epoch: 0, update in batch 15500/???, loss: 6.222006797790527\n",
"epoch: 0, update in batch 15600/???, loss: 6.515635013580322\n",
"epoch: 0, update in batch 15700/???, loss: 6.40108585357666\n",
"epoch: 0, update in batch 15800/???, loss: 6.36106014251709\n",
"epoch: 0, update in batch 15900/???, loss: 6.533608436584473\n",
"epoch: 0, update in batch 16000/???, loss: 6.662516117095947\n",
"epoch: 0, update in batch 16100/???, loss: 7.284195899963379\n",
"epoch: 0, update in batch 16200/???, loss: 6.6524176597595215\n",
"epoch: 0, update in batch 16300/???, loss: 6.430756568908691\n",
"epoch: 0, update in batch 16400/???, loss: 7.515387058258057\n",
"epoch: 0, update in batch 16500/???, loss: 6.938241481781006\n",
"epoch: 0, update in batch 16600/???, loss: 5.860864162445068\n",
"epoch: 0, update in batch 16700/???, loss: 6.451329231262207\n",
"epoch: 0, update in batch 16800/???, loss: 6.5510663986206055\n",
"epoch: 0, update in batch 16900/???, loss: 7.3591437339782715\n",
"epoch: 0, update in batch 17000/???, loss: 6.158746719360352\n",
"epoch: 0, update in batch 17100/???, loss: 7.202520847320557\n",
"epoch: 0, update in batch 17200/???, loss: 6.80673885345459\n",
"epoch: 0, update in batch 17300/???, loss: 6.698304653167725\n",
"epoch: 0, update in batch 17400/???, loss: 5.743161201477051\n",
"epoch: 0, update in batch 17500/???, loss: 6.518529415130615\n",
"epoch: 0, update in batch 17600/???, loss: 6.021708011627197\n",
"epoch: 0, update in batch 17700/???, loss: 6.354712963104248\n",
"epoch: 0, update in batch 17800/???, loss: 6.323357582092285\n",
"epoch: 0, update in batch 17900/???, loss: 6.61548376083374\n",
"epoch: 0, update in batch 18000/???, loss: 6.600308895111084\n",
"epoch: 0, update in batch 18100/???, loss: 6.794068336486816\n",
"epoch: 0, update in batch 18200/???, loss: 7.487390041351318\n",
"epoch: 0, update in batch 18300/???, loss: 5.973461627960205\n",
"epoch: 0, update in batch 18400/???, loss: 6.891515254974365\n",
"epoch: 0, update in batch 18500/???, loss: 5.897144317626953\n",
"epoch: 0, update in batch 18600/???, loss: 6.6016364097595215\n",
"epoch: 0, update in batch 18700/???, loss: 6.948650360107422\n",
"epoch: 0, update in batch 18800/???, loss: 7.221627235412598\n",
"epoch: 0, update in batch 18900/???, loss: 6.817994117736816\n",
"epoch: 0, update in batch 19000/???, loss: 5.730655193328857\n",
"epoch: 0, update in batch 19100/???, loss: 6.236818790435791\n",
"epoch: 0, update in batch 19200/???, loss: 7.178666114807129\n",
"epoch: 0, update in batch 19300/???, loss: 6.77465295791626\n",
"epoch: 0, update in batch 19400/???, loss: 6.996792793273926\n",
"epoch: 0, update in batch 19500/???, loss: 6.80951452255249\n",
"epoch: 0, update in batch 19600/???, loss: 7.1757965087890625\n",
"epoch: 0, update in batch 19700/???, loss: 8.400952339172363\n",
"epoch: 0, update in batch 19800/???, loss: 7.1904473304748535\n",
"epoch: 0, update in batch 19900/???, loss: 6.339241981506348\n",
"epoch: 0, update in batch 20000/???, loss: 7.078637599945068\n",
"epoch: 0, update in batch 20100/???, loss: 5.015235900878906\n",
"epoch: 0, update in batch 20200/???, loss: 6.763777732849121\n",
"epoch: 0, update in batch 20300/???, loss: 6.543915748596191\n",
"epoch: 0, update in batch 20400/???, loss: 6.027902603149414\n",
"epoch: 0, update in batch 20500/???, loss: 6.710694789886475\n",
"epoch: 0, update in batch 20600/???, loss: 6.800978660583496\n",
"epoch: 0, update in batch 20700/???, loss: 6.371827125549316\n",
"epoch: 0, update in batch 20800/???, loss: 5.952463626861572\n",
"epoch: 0, update in batch 20900/???, loss: 6.317960739135742\n",
"epoch: 0, update in batch 21000/???, loss: 7.178386688232422\n",
"epoch: 0, update in batch 21100/???, loss: 6.887454986572266\n",
"epoch: 0, update in batch 21200/???, loss: 6.468400478363037\n",
"epoch: 0, update in batch 21300/???, loss: 7.8383684158325195\n",
"epoch: 0, update in batch 21400/???, loss: 5.850740909576416\n",
"epoch: 0, update in batch 21500/???, loss: 6.065464973449707\n",
"epoch: 0, update in batch 21600/???, loss: 7.537625312805176\n",
"epoch: 0, update in batch 21700/???, loss: 6.095994472503662\n",
"epoch: 0, update in batch 21800/???, loss: 6.342766761779785\n",
"epoch: 0, update in batch 21900/???, loss: 5.810301780700684\n",
"epoch: 0, update in batch 22000/???, loss: 6.447206974029541\n",
"epoch: 0, update in batch 22100/???, loss: 7.0662946701049805\n",
"epoch: 0, update in batch 22200/???, loss: 6.535088539123535\n",
"epoch: 0, update in batch 22300/???, loss: 7.017588138580322\n",
"epoch: 0, update in batch 22400/???, loss: 5.067782402038574\n",
"epoch: 0, update in batch 22500/???, loss: 6.493170738220215\n",
"epoch: 0, update in batch 22600/???, loss: 5.642627716064453\n",
"epoch: 0, update in batch 22700/???, loss: 7.200662136077881\n",
"epoch: 0, update in batch 22800/???, loss: 6.137134075164795\n",
"epoch: 0, update in batch 22900/???, loss: 6.367280006408691\n",
"epoch: 0, update in batch 23000/???, loss: 7.458652496337891\n",
"epoch: 0, update in batch 23100/???, loss: 6.515708923339844\n",
"epoch: 0, update in batch 23200/???, loss: 7.526422023773193\n",
"epoch: 0, update in batch 23300/???, loss: 6.653852939605713\n",
"epoch: 0, update in batch 23400/???, loss: 6.737251281738281\n",
"epoch: 0, update in batch 23500/???, loss: 6.493605136871338\n",
"epoch: 0, update in batch 23600/???, loss: 6.132809638977051\n",
"epoch: 0, update in batch 23700/???, loss: 6.406940460205078\n",
"epoch: 0, update in batch 23800/???, loss: 6.84005880355835\n",
"epoch: 0, update in batch 23900/???, loss: 6.830739498138428\n",
"epoch: 0, update in batch 24000/???, loss: 5.862464427947998\n",
"epoch: 0, update in batch 24100/???, loss: 6.382696628570557\n",
"epoch: 0, update in batch 24200/???, loss: 5.722895622253418\n",
"epoch: 0, update in batch 24300/???, loss: 6.697083473205566\n",
"epoch: 0, update in batch 24400/???, loss: 6.56771183013916\n",
"epoch: 0, update in batch 24500/???, loss: 7.566462516784668\n",
"epoch: 0, update in batch 24600/???, loss: 6.217026710510254\n",
"epoch: 0, update in batch 24700/???, loss: 7.164259433746338\n",
"epoch: 0, update in batch 24800/???, loss: 6.460946083068848\n",
"epoch: 0, update in batch 24900/???, loss: 6.333778381347656\n",
"epoch: 0, update in batch 25000/???, loss: 6.522342681884766\n",
"epoch: 0, update in batch 25100/???, loss: 6.270648002624512\n",
"epoch: 0, update in batch 25200/???, loss: 7.118265628814697\n",
"epoch: 0, update in batch 25300/???, loss: 5.8695197105407715\n",
"epoch: 0, update in batch 25400/???, loss: 5.92995023727417\n",
"epoch: 0, update in batch 25500/???, loss: 6.202570915222168\n",
"epoch: 0, update in batch 25600/???, loss: 6.4268975257873535\n",
"epoch: 0, update in batch 25700/???, loss: 6.710567474365234\n",
"epoch: 0, update in batch 25800/???, loss: 6.130914688110352\n",
"epoch: 0, update in batch 25900/???, loss: 6.082686424255371\n",
"epoch: 0, update in batch 26000/???, loss: 6.111697196960449\n",
"epoch: 0, update in batch 26100/???, loss: 7.320557594299316\n",
"epoch: 0, update in batch 26200/???, loss: 6.227985858917236\n",
"epoch: 0, update in batch 26300/???, loss: 6.204974174499512\n",
"epoch: 0, update in batch 26400/???, loss: 6.658400058746338\n",
"epoch: 0, update in batch 26500/???, loss: 5.911742687225342\n",
"epoch: 0, update in batch 26600/???, loss: 6.891500949859619\n",
"epoch: 0, update in batch 26700/???, loss: 5.763737201690674\n",
"epoch: 0, update in batch 26800/???, loss: 5.757307529449463\n",
"epoch: 0, update in batch 26900/???, loss: 6.076601982116699\n",
"epoch: 0, update in batch 27000/???, loss: 6.193032264709473\n",
"epoch: 0, update in batch 27100/???, loss: 6.120661735534668\n",
"epoch: 0, update in batch 27200/???, loss: 6.5425519943237305\n",
"epoch: 0, update in batch 27300/???, loss: 6.511394500732422\n",
"epoch: 0, update in batch 27400/???, loss: 7.127263069152832\n",
"epoch: 0, update in batch 27500/???, loss: 6.134243488311768\n",
"epoch: 0, update in batch 27600/???, loss: 6.5747809410095215\n",
"epoch: 0, update in batch 27700/???, loss: 6.351634979248047\n",
"epoch: 0, update in batch 27800/???, loss: 5.589611530303955\n",
"epoch: 0, update in batch 27900/???, loss: 6.916817665100098\n",
"epoch: 0, update in batch 28000/???, loss: 5.711864948272705\n",
"epoch: 0, update in batch 28100/???, loss: 6.921398162841797\n",
"epoch: 0, update in batch 28200/???, loss: 6.785823822021484\n",
"epoch: 0, update in batch 28300/???, loss: 6.007838249206543\n",
"epoch: 0, update in batch 28400/???, loss: 6.338862419128418\n",
"epoch: 0, update in batch 28500/???, loss: 6.9078168869018555\n",
"epoch: 0, update in batch 28600/???, loss: 6.710842132568359\n",
"epoch: 0, update in batch 28700/???, loss: 6.592329502105713\n",
"epoch: 0, update in batch 28800/???, loss: 6.184128761291504\n",
"epoch: 0, update in batch 28900/???, loss: 6.209361553192139\n",
"epoch: 0, update in batch 29000/???, loss: 7.067984104156494\n",
"epoch: 0, update in batch 29100/???, loss: 6.479236602783203\n",
"epoch: 0, update in batch 29200/???, loss: 6.413198947906494\n",
"epoch: 0, update in batch 29300/???, loss: 6.638579368591309\n",
"epoch: 0, update in batch 29400/???, loss: 5.938233375549316\n",
"epoch: 0, update in batch 29500/???, loss: 6.8490891456604\n",
"epoch: 0, update in batch 29600/???, loss: 6.111110210418701\n",
"epoch: 0, update in batch 29700/???, loss: 6.959462642669678\n",
"epoch: 0, update in batch 29800/???, loss: 6.964720726013184\n",
"epoch: 0, update in batch 29900/???, loss: 6.2007527351379395\n",
"epoch: 0, update in batch 30000/???, loss: 6.803907871246338\n",
"epoch: 0, update in batch 30100/???, loss: 5.665301322937012\n",
"epoch: 0, update in batch 30200/???, loss: 6.913702487945557\n",
"epoch: 0, update in batch 30300/???, loss: 6.824265956878662\n",
"epoch: 0, update in batch 30400/???, loss: 6.131905555725098\n",
"epoch: 0, update in batch 30500/???, loss: 5.799595832824707\n",
"epoch: 0, update in batch 30600/???, loss: 6.846949100494385\n",
"epoch: 0, update in batch 30700/???, loss: 6.481771945953369\n",
"epoch: 0, update in batch 30800/???, loss: 6.5581254959106445\n",
"epoch: 0, update in batch 30900/???, loss: 6.111696720123291\n",
"epoch: 0, update in batch 31000/???, loss: 4.8547563552856445\n",
"epoch: 0, update in batch 31100/???, loss: 6.5503740310668945\n",
"epoch: 0, update in batch 31200/???, loss: 6.212404251098633\n",
"epoch: 0, update in batch 31300/???, loss: 5.761624336242676\n",
"epoch: 0, update in batch 31400/???, loss: 7.043508052825928\n",
"epoch: 0, update in batch 31500/???, loss: 8.301980018615723\n",
"epoch: 0, update in batch 31600/???, loss: 5.655745506286621\n",
"epoch: 0, update in batch 31700/???, loss: 7.116888999938965\n",
"epoch: 0, update in batch 31800/???, loss: 6.237078666687012\n",
"epoch: 0, update in batch 31900/???, loss: 6.990937232971191\n",
"epoch: 0, update in batch 32000/???, loss: 6.327075958251953\n",
"epoch: 0, update in batch 32100/???, loss: 6.831456184387207\n",
"epoch: 0, update in batch 32200/???, loss: 6.511493682861328\n",
"epoch: 0, update in batch 32300/???, loss: 6.719797611236572\n",
"epoch: 0, update in batch 32400/???, loss: 6.46258544921875\n",
"epoch: 0, update in batch 32500/???, loss: 7.349535942077637\n",
"epoch: 0, update in batch 32600/???, loss: 5.773186683654785\n",
"epoch: 0, update in batch 32700/???, loss: 6.072037696838379\n",
"epoch: 0, update in batch 32800/???, loss: 7.044579982757568\n",
"epoch: 0, update in batch 32900/???, loss: 6.290024757385254\n",
"epoch: 0, update in batch 33000/???, loss: 7.101686000823975\n",
"epoch: 0, update in batch 33100/???, loss: 6.590539455413818\n",
"epoch: 0, update in batch 33200/???, loss: 6.944089412689209\n",
"epoch: 0, update in batch 33300/???, loss: 6.6709442138671875\n",
"epoch: 0, update in batch 33400/???, loss: 7.119935035705566\n",
"epoch: 0, update in batch 33500/???, loss: 6.845646858215332\n",
"epoch: 0, update in batch 33600/???, loss: 6.941410064697266\n",
"epoch: 0, update in batch 33700/???, loss: 6.341822624206543\n",
"epoch: 0, update in batch 33800/???, loss: 6.98660945892334\n",
"epoch: 0, update in batch 33900/???, loss: 7.544371128082275\n",
"epoch: 0, update in batch 34000/???, loss: 6.844598293304443\n",
"epoch: 0, update in batch 34100/???, loss: 6.958268642425537\n",
"epoch: 0, update in batch 34200/???, loss: 6.6372880935668945\n"
]
}
],
"source": [
"# One epoch over the forward stream in mini-batches of 64 windows.\n",
"train(train_dataset_front, model_front, max_epochs=1, batch_size=64)"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "f0512b6e-098d-4264-b52e-1d9ce88311f0",
"metadata": {},
"outputs": [],
"source": [
"def predict_probs(left_tokens, right_tokens, front_model=None, back_model=None, top_k=30):\n",
"    \"\"\"Build one gap-prediction line in ``word:prob ... :rest`` format.\n",
"\n",
"    ``left_tokens`` (words before the gap, in reading order) are fed to the\n",
"    forward model. ``right_tokens`` (words after the gap, expected in reversed\n",
"    order so the word next to the gap comes last) are fed to the backward\n",
"    model. The next-word distributions are averaged; the ``top_k`` most\n",
"    probable words other than ``'<unk>'`` are listed, and the leftover\n",
"    probability mass is written as the final ``:rest`` entry.\n",
"\n",
"    ``front_model`` defaults to the global ``model_front``. ``back_model``\n",
"    defaults to a global ``model_back`` when one exists; otherwise only the\n",
"    forward distribution is used (previously this raised NameError).\n",
"    \"\"\"\n",
"    if front_model is None:\n",
"        front_model = model_front\n",
"    if back_model is None:\n",
"        back_model = globals().get('model_back')\n",
"\n",
"    unk_id = key_to_index['<unk>']\n",
"\n",
"    def next_word_probs(model, tokens):\n",
"        # Softmax over the logits at the last context position, as numpy.\n",
"        model.eval()  # disable LSTM dropout for inference\n",
"        model_device = next(model.parameters()).device\n",
"        ids = torch.tensor([[key_to_index.get(w, unk_id) for w in tokens]], device=model_device)\n",
"        logits, _ = model(ids)\n",
"        return torch.softmax(logits[0, -1], dim=0).cpu().numpy()\n",
"\n",
"    with torch.no_grad():\n",
"        probs = next_word_probs(front_model, left_tokens)\n",
"        if back_model is not None:\n",
"            probs = (probs + next_word_probs(back_model, right_tokens)) / 2\n",
"\n",
"    # Highest-probability candidates first, never the '<unk>' placeholder.\n",
"    order = np.argsort(-probs)\n",
"    ranked = order[order != unk_id][:top_k]\n",
"\n",
"    entries = [f'{index_to_key[i]}:{probs[i]}' for i in ranked]\n",
"    # Clamp tiny negative values caused by float rounding.\n",
"    rest = max(0.0, 1.0 - float(probs[ranked].sum()))\n",
"    entries.append(f':{rest}')\n",
"    return ' '.join(entries)"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "90938496-416a-4618-bf88-eddb3cc6e8da",
"metadata": {},
"outputs": [],
"source": [
"# Dev and test inputs share the training input's TSV layout and read options.\n",
"# NOTE(review): any row skipped by on_bad_lines='skip' is missing from the\n",
"# corresponding out.tsv, shifting every later prediction - confirm none are dropped.\n",
"dev_data, test_data = (\n",
"    pd.read_csv(path, sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)\n",
"    for path in ('dev-0/in.tsv.xz', 'test-A/in.tsv.xz')\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "a0d4bd42-0499-46af-ab0e-5dd23dacf47a",
"metadata": {},
"outputs": [],
"source": [
"# One prediction line per dev row; rows with too little context get ':1.0'.\n",
"# Explicit UTF-8 so non-ASCII tokens do not depend on the platform encoding.\n",
"with open('dev-0/out.tsv', 'w', encoding='utf-8') as out_file:\n",
"    for _, row in dev_data.iterrows():\n",
"        left_words = word_tokenize(clean_text(str(row[6])))\n",
"        # Reverse the right context so the word next to the gap comes last,\n",
"        # matching the input order predict_probs uses for the backward model.\n",
"        right_words = word_tokenize(clean_text(str(row[7])))[::-1]\n",
"        if min(len(left_words), len(right_words)) < 6:\n",
"            prediction = ':1.0'\n",
"        else:\n",
"            prediction = predict_probs(left_words[-5:], right_words[-5:])\n",
"        out_file.write(prediction + '\\n')"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "819be88f-6336-4853-8015-33c9f3392aa4",
"metadata": {},
"outputs": [],
"source": [
"# One prediction line per test row; rows with too little context get ':1.0'.\n",
"# Explicit UTF-8 so non-ASCII tokens do not depend on the platform encoding.\n",
"with open('test-A/out.tsv', 'w', encoding='utf-8') as out_file:\n",
"    for _, row in test_data.iterrows():\n",
"        left_words = word_tokenize(clean_text(str(row[6])))\n",
"        # Reverse the right context so the word next to the gap comes last,\n",
"        # matching the input order predict_probs uses for the backward model.\n",
"        right_words = word_tokenize(clean_text(str(row[7])))[::-1]\n",
"        if min(len(left_words), len(right_words)) < 6:\n",
"            prediction = ':1.0'\n",
"        else:\n",
"            prediction = predict_probs(left_words[-5:], right_words[-5:])\n",
"        out_file.write(prediction + '\\n')"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8de69da7-999e-4653-8308-b9d7e6fe24c1",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.7"
}
},
"nbformat": 4,
"nbformat_minor": 5
}