{ "cells": [ { "cell_type": "code", "execution_count": 1, "id": "03de852a", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import numpy as np\n", "import regex as re\n", "import csv\n", "import torch\n", "from torch import nn\n", "from gensim.models import Word2Vec\n", "from nltk.tokenize import word_tokenize" ] }, { "cell_type": "code", "execution_count": 2, "id": "73497953", "metadata": {}, "outputs": [], "source": [ "torch.cuda.empty_cache()\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'" ] }, { "cell_type": "code", "execution_count": 3, "id": "4227ef55", "metadata": {}, "outputs": [], "source": [ "def clean_text(text):\n", " text = text.lower().replace('-\\\\\\\\\\\\\\\\n', '').replace('\\\\\\\\\\\\\\\\n', ' ')\n", " text = re.sub(r'\\p{P}', '', text)\n", " text = text.replace(\"'t\", \" not\").replace(\"'s\", \" is\").replace(\"'ll\", \" will\").replace(\"'m\", \" am\").replace(\"'ve\", \" have\")\n", "\n", " return text" ] }, { "cell_type": "code", "execution_count": 4, "id": "758cf94a", "metadata": {}, "outputs": [], "source": [ "train_data = pd.read_csv('train/in.tsv.xz', sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)\n", "train_labels = pd.read_csv('train/expected.tsv', sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)\n", "\n", "train_data = train_data[[6, 7]]\n", "train_data = pd.concat([train_data, train_labels], axis=1)" ] }, { "cell_type": "code", "execution_count": 5, "id": "384922c7", "metadata": {}, "outputs": [], "source": [ "class TrainCorpus:\n", " def __init__(self, data):\n", " self.data = data\n", " \n", " def __iter__(self):\n", " for _, row in self.data.iterrows():\n", " text = str(row[6]) + str(row[0]) + str(row[7])\n", " text = clean_text(text)\n", " yield word_tokenize(text)" ] }, { "cell_type": "code", "execution_count": 6, "id": "3da5f19b", "metadata": {}, "outputs": [], "source": [ "train_sentences = TrainCorpus(train_data.head(80000))\n", "w2v_model = 
Word2Vec(vector_size=100, min_count=10)" ] }, { "cell_type": "code", "execution_count": 7, "id": "183d43be", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "81477\n" ] } ], "source": [ "w2v_model.build_vocab(corpus_iterable=train_sentences)\n", "\n", "key_to_index = w2v_model.wv.key_to_index\n", "index_to_key = w2v_model.wv.index_to_key\n", "\n", "# Reserve the empty string as the out-of-vocabulary sentinel index.\n", "index_to_key.append('')\n", "key_to_index[''] = len(index_to_key) - 1\n", "\n", "vocab_size = len(index_to_key)\n", "print(vocab_size)" ] }, { "cell_type": "code", "execution_count": 8, "id": "e63dd9fe", "metadata": {}, "outputs": [], "source": [ "class TrainDataset(torch.utils.data.IterableDataset):\n", "    \"\"\"Streams (input_window, target_window) index pairs for next-word LM training.\n", "\n", "    With reversed=True the token stream is reversed, so the model learns to\n", "    predict leftwards (used for the right-context model).\n", "    \"\"\"\n", "    def __init__(self, data, index_to_key, key_to_index, reversed=False):\n", "        self.reversed = reversed\n", "        self.data = data\n", "        self.index_to_key = index_to_key\n", "        self.key_to_index = key_to_index\n", "        self.vocab_size = len(key_to_index)\n", "\n", "    def __iter__(self):\n", "        for _, row in self.data.iterrows():\n", "            # Join left context, gap word and right context with spaces so that\n", "            # boundary words do not merge into bogus tokens.\n", "            text = str(row[6]) + ' ' + str(row[0]) + ' ' + str(row[7])\n", "            text = clean_text(text)\n", "            tokens = word_tokenize(text)\n", "            if self.reversed:\n", "                tokens = list(reversed(tokens))\n", "            # Sliding window: five input tokens predict the same window shifted\n", "            # right by one position (standard next-token targets).\n", "            for i in range(5, len(tokens), 1):\n", "                input_context = tokens[i-5:i]\n", "                target_context = tokens[i-4:i+1]\n", "\n", "                # Unknown words map to the '' out-of-vocabulary sentinel.\n", "                input_embed = [self.key_to_index[word] if word in self.key_to_index else self.key_to_index[''] for word in input_context]\n", "                target_embed = [self.key_to_index[word] if word in self.key_to_index else self.key_to_index[''] for word in target_context]\n", "\n", "                yield np.asarray(input_embed, dtype=np.int64), np.asarray(target_embed, dtype=np.int64)" ] }, { "cell_type": "code", "execution_count": 9, "id": "7c60ddc1",
"metadata": {}, "outputs": [], "source": [ "class Model(nn.Module):\n", " def __init__(self, embed_size, vocab_size):\n", " super(Model, self).__init__()\n", " self.embed_size = embed_size\n", " self.vocab_size = vocab_size\n", " self.lstm_size = 128\n", " self.num_layers = 2\n", " \n", " self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=self.embed_size)\n", " self.lstm = nn.LSTM(input_size=self.embed_size, hidden_size=self.lstm_size, num_layers=self.num_layers, dropout=0.2)\n", " self.fc = nn.Linear(self.lstm_size, vocab_size)\n", "\n", " def forward(self, x, prev_state = None):\n", " embed = self.embed(x)\n", " output, state = self.lstm(embed, prev_state)\n", " logits = self.fc(output)\n", " probs = torch.softmax(logits, dim=1)\n", " return logits, state\n", "\n", " def init_state(self, sequence_length):\n", " zeros = torch.zeros(self.num_layers, sequence_length, self.gru_size).to(device)\n", " return (zeros, zeros)" ] }, { "cell_type": "code", "execution_count": 10, "id": "1c7b8fab", "metadata": {}, "outputs": [], "source": [ "from torch.utils.data import DataLoader\n", "from torch.optim import Adam\n", "\n", "def train(dataset, model, max_epochs, batch_size):\n", " model.train()\n", "\n", " dataloader = DataLoader(dataset, batch_size=batch_size)\n", " criterion = nn.CrossEntropyLoss()\n", " optimizer = Adam(model.parameters(), lr=0.001)\n", "\n", " for epoch in range(max_epochs):\n", " for batch, (x, y) in enumerate(dataloader):\n", " optimizer.zero_grad()\n", " \n", " x = x.to(device)\n", " y = y.to(device)\n", " \n", " y_pred, (state_h, state_c) = model(x)\n", " loss = criterion(y_pred.transpose(1, 2), y)\n", "\n", " loss.backward()\n", " optimizer.step()\n", " \n", " if batch % 1000 == 0:\n", " print(f'epoch: {epoch}, update in batch {batch}/???, loss: {loss.item()}')" ] }, { "cell_type": "code", "execution_count": 11, "id": "3531d21d", "metadata": {}, "outputs": [], "source": [ "train_dataset_front = TrainDataset(train_data.head(80000), 
index_to_key, key_to_index, False)\n", "train_dataset_back = TrainDataset(train_data.tail(80000), index_to_key, key_to_index, True)" ] }, { "cell_type": "code", "execution_count": 12, "id": "f72f1f6d", "metadata": {}, "outputs": [], "source": [ "model_front = Model(100, vocab_size).to(device)\n", "model_back = Model(100, vocab_size).to(device)" ] }, { "cell_type": "code", "execution_count": 13, "id": "d608d9fe", "metadata": { "scrolled": true }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 0, update in batch 0/???, loss: 11.314821243286133\n", "epoch: 0, update in batch 1000/???, loss: 6.876476287841797\n", "epoch: 0, update in batch 2000/???, loss: 7.133523464202881\n", "epoch: 0, update in batch 3000/???, loss: 6.979971885681152\n", "epoch: 0, update in batch 4000/???, loss: 7.018368721008301\n", "epoch: 0, update in batch 5000/???, loss: 6.494096279144287\n", "epoch: 0, update in batch 6000/???, loss: 6.448479652404785\n", "epoch: 0, update in batch 7000/???, loss: 6.526387691497803\n", "epoch: 0, update in batch 8000/???, loss: 6.536323547363281\n", "epoch: 0, update in batch 9000/???, loss: 6.4919538497924805\n", "epoch: 0, update in batch 10000/???, loss: 6.435188293457031\n", "epoch: 0, update in batch 11000/???, loss: 6.934823513031006\n", "epoch: 0, update in batch 12000/???, loss: 7.410381317138672\n", "epoch: 0, update in batch 13000/???, loss: 8.227864265441895\n", "epoch: 0, update in batch 14000/???, loss: 6.7139105796813965\n", "epoch: 0, update in batch 15000/???, loss: 6.82781457901001\n", "epoch: 0, update in batch 16000/???, loss: 6.637822151184082\n", "epoch: 0, update in batch 17000/???, loss: 6.2633233070373535\n", "epoch: 0, update in batch 18000/???, loss: 6.512040138244629\n", "epoch: 0, update in batch 19000/???, loss: 5.745478630065918\n", "epoch: 0, update in batch 20000/???, loss: 7.039064884185791\n", "epoch: 0, update in batch 21000/???, loss: 7.151158332824707\n", "epoch: 0, update in batch 22000/???, 
loss: 6.460148811340332\n", "epoch: 0, update in batch 23000/???, loss: 7.396632194519043\n", "epoch: 0, update in batch 24000/???, loss: 5.907363414764404\n", "epoch: 0, update in batch 25000/???, loss: 6.669890403747559\n", "epoch: 0, update in batch 26000/???, loss: 6.032290458679199\n", "epoch: 0, update in batch 27000/???, loss: 6.192468166351318\n", "epoch: 0, update in batch 28000/???, loss: 5.757508277893066\n", "epoch: 0, update in batch 29000/???, loss: 7.097552299499512\n", "epoch: 0, update in batch 30000/???, loss: 6.8356804847717285\n", "epoch: 0, update in batch 31000/???, loss: 4.938998699188232\n", "epoch: 0, update in batch 32000/???, loss: 6.34550142288208\n", "epoch: 0, update in batch 33000/???, loss: 7.154759883880615\n", "epoch: 0, update in batch 34000/???, loss: 6.8563055992126465\n", "epoch: 0, update in batch 35000/???, loss: 6.831148624420166\n", "epoch: 0, update in batch 36000/???, loss: 6.867754936218262\n", "epoch: 0, update in batch 37000/???, loss: 6.911463260650635\n", "epoch: 0, update in batch 38000/???, loss: 6.637528896331787\n", "epoch: 0, update in batch 39000/???, loss: 6.822340488433838\n", "epoch: 0, update in batch 40000/???, loss: 6.122499942779541\n", "epoch: 0, update in batch 41000/???, loss: 6.454296112060547\n", "epoch: 0, update in batch 42000/???, loss: 7.5895185470581055\n", "epoch: 0, update in batch 43000/???, loss: 5.775805473327637\n", "epoch: 0, update in batch 44000/???, loss: 5.973118305206299\n", "epoch: 0, update in batch 45000/???, loss: 5.7727460861206055\n", "epoch: 0, update in batch 46000/???, loss: 6.376847267150879\n", "epoch: 0, update in batch 47000/???, loss: 5.739894866943359\n", "epoch: 0, update in batch 48000/???, loss: 6.390743732452393\n", "epoch: 0, update in batch 49000/???, loss: 7.724233150482178\n", "epoch: 0, update in batch 50000/???, loss: 5.242608070373535\n", "epoch: 0, update in batch 51000/???, loss: 5.412053108215332\n", "epoch: 0, update in batch 52000/???, loss: 
6.590373992919922\n", "epoch: 0, update in batch 53000/???, loss: 6.46323299407959\n", "epoch: 0, update in batch 54000/???, loss: 6.9850263595581055\n", "epoch: 0, update in batch 55000/???, loss: 7.3167219161987305\n", "epoch: 0, update in batch 56000/???, loss: 6.285423278808594\n", "epoch: 0, update in batch 57000/???, loss: 7.417998313903809\n", "epoch: 0, update in batch 58000/???, loss: 6.437861442565918\n", "epoch: 0, update in batch 59000/???, loss: 6.522177219390869\n", "epoch: 0, update in batch 60000/???, loss: 5.9156928062438965\n", "epoch: 0, update in batch 61000/???, loss: 4.946429252624512\n", "epoch: 0, update in batch 62000/???, loss: 6.633675575256348\n", "epoch: 0, update in batch 63000/???, loss: 7.357038974761963\n", "epoch: 0, update in batch 64000/???, loss: 5.774768352508545\n", "epoch: 0, update in batch 65000/???, loss: 6.289044380187988\n", "epoch: 0, update in batch 66000/???, loss: 6.127488136291504\n", "epoch: 0, update in batch 67000/???, loss: 5.059685230255127\n", "epoch: 0, update in batch 68000/???, loss: 6.5439910888671875\n", "epoch: 0, update in batch 69000/???, loss: 6.679286956787109\n", "epoch: 0, update in batch 70000/???, loss: 7.2232346534729\n", "epoch: 0, update in batch 71000/???, loss: 6.13685941696167\n", "epoch: 0, update in batch 72000/???, loss: 5.766592025756836\n", "epoch: 0, update in batch 73000/???, loss: 6.772070407867432\n", "epoch: 0, update in batch 74000/???, loss: 7.369122505187988\n", "epoch: 0, update in batch 75000/???, loss: 6.598935127258301\n", "epoch: 0, update in batch 76000/???, loss: 5.948511600494385\n", "epoch: 0, update in batch 77000/???, loss: 6.507765769958496\n", "epoch: 0, update in batch 78000/???, loss: 5.09373664855957\n", "epoch: 0, update in batch 79000/???, loss: 5.9862494468688965\n", "epoch: 0, update in batch 80000/???, loss: 6.106108665466309\n", "epoch: 0, update in batch 81000/???, loss: 5.2747578620910645\n", "epoch: 0, update in batch 82000/???, loss: 
6.324326515197754\n", "epoch: 0, update in batch 83000/???, loss: 5.914392471313477\n", "epoch: 0, update in batch 84000/???, loss: 6.641409873962402\n", "epoch: 0, update in batch 85000/???, loss: 6.287321090698242\n", "epoch: 0, update in batch 86000/???, loss: 6.510883331298828\n", "epoch: 0, update in batch 87000/???, loss: 6.458550930023193\n", "epoch: 0, update in batch 88000/???, loss: 6.07730770111084\n", "epoch: 0, update in batch 89000/???, loss: 6.2387471199035645\n", "epoch: 0, update in batch 90000/???, loss: 5.63344669342041\n", "epoch: 0, update in batch 91000/???, loss: 6.277956962585449\n", "epoch: 0, update in batch 92000/???, loss: 6.841054439544678\n", "epoch: 0, update in batch 93000/???, loss: 6.458809852600098\n", "epoch: 0, update in batch 94000/???, loss: 7.471741676330566\n", "epoch: 0, update in batch 95000/???, loss: 6.461136817932129\n", "epoch: 0, update in batch 96000/???, loss: 5.718675136566162\n", "epoch: 0, update in batch 97000/???, loss: 4.4265007972717285\n", "epoch: 0, update in batch 98000/???, loss: 7.05142879486084\n", "epoch: 0, update in batch 99000/???, loss: 6.341854572296143\n", "epoch: 0, update in batch 100000/???, loss: 6.834918022155762\n", "epoch: 0, update in batch 101000/???, loss: 5.367598056793213\n", "epoch: 0, update in batch 102000/???, loss: 5.716221809387207\n", "epoch: 0, update in batch 103000/???, loss: 6.9465742111206055\n", "epoch: 0, update in batch 104000/???, loss: 5.976019382476807\n", "epoch: 0, update in batch 105000/???, loss: 6.125661849975586\n", "epoch: 0, update in batch 106000/???, loss: 6.724229335784912\n", "epoch: 0, update in batch 107000/???, loss: 6.446004390716553\n", "epoch: 0, update in batch 108000/???, loss: 6.4710845947265625\n", "epoch: 0, update in batch 109000/???, loss: 6.5926103591918945\n", "epoch: 0, update in batch 110000/???, loss: 6.966839790344238\n", "epoch: 0, update in batch 111000/???, loss: 7.263918876647949\n", "epoch: 0, update in batch 112000/???, loss: 
6.7561750411987305\n", "epoch: 0, update in batch 113000/???, loss: 6.142555236816406\n", "epoch: 0, update in batch 114000/???, loss: 5.974082946777344\n", "epoch: 0, update in batch 115000/???, loss: 5.565796852111816\n", "epoch: 0, update in batch 116000/???, loss: 6.4826202392578125\n", "epoch: 0, update in batch 117000/???, loss: 5.643266201019287\n", "epoch: 0, update in batch 118000/???, loss: 6.360909461975098\n", "epoch: 0, update in batch 119000/???, loss: 5.4074201583862305\n", "epoch: 0, update in batch 120000/???, loss: 7.1339569091796875\n", "epoch: 0, update in batch 121000/???, loss: 6.786561012268066\n", "epoch: 0, update in batch 122000/???, loss: 6.329574108123779\n", "epoch: 0, update in batch 123000/???, loss: 7.21968936920166\n", "epoch: 0, update in batch 124000/???, loss: 5.351359844207764\n", "epoch: 0, update in batch 125000/???, loss: 7.962380886077881\n", "epoch: 0, update in batch 126000/???, loss: 6.351782321929932\n", "epoch: 0, update in batch 127000/???, loss: 6.8343048095703125\n", "epoch: 0, update in batch 128000/???, loss: 6.129800319671631\n", "epoch: 0, update in batch 129000/???, loss: 6.68627405166626\n", "epoch: 0, update in batch 130000/???, loss: 6.498664855957031\n", "epoch: 0, update in batch 131000/???, loss: 5.724549293518066\n", "epoch: 0, update in batch 132000/???, loss: 7.041095733642578\n", "epoch: 0, update in batch 133000/???, loss: 5.901988983154297\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch: 0, update in batch 134000/???, loss: 6.055495262145996\n", "epoch: 0, update in batch 135000/???, loss: 6.363399982452393\n", "epoch: 0, update in batch 136000/???, loss: 7.45733642578125\n", "epoch: 0, update in batch 137000/???, loss: 6.960203647613525\n", "epoch: 0, update in batch 138000/???, loss: 6.986503601074219\n", "epoch: 0, update in batch 139000/???, loss: 5.7938127517700195\n", "epoch: 0, update in batch 140000/???, loss: 5.559916019439697\n", "epoch: 0, update in batch 141000/???, 
loss: 5.551616668701172\n", "epoch: 0, update in batch 142000/???, loss: 5.386819839477539\n", "epoch: 0, update in batch 143000/???, loss: 6.826618194580078\n", "epoch: 0, update in batch 144000/???, loss: 6.106345176696777\n", "epoch: 0, update in batch 145000/???, loss: 6.812024116516113\n", "epoch: 0, update in batch 146000/???, loss: 6.347486972808838\n", "epoch: 0, update in batch 147000/???, loss: 6.20189094543457\n", "epoch: 0, update in batch 148000/???, loss: 5.5717034339904785\n", "epoch: 0, update in batch 149000/???, loss: 6.884232521057129\n", "epoch: 0, update in batch 150000/???, loss: 6.8074846267700195\n", "epoch: 0, update in batch 151000/???, loss: 7.028794288635254\n", "epoch: 0, update in batch 152000/???, loss: 5.201214790344238\n", "epoch: 0, update in batch 153000/???, loss: 5.1864013671875\n", "epoch: 0, update in batch 154000/???, loss: 6.4473114013671875\n", "epoch: 0, update in batch 155000/???, loss: 4.9203643798828125\n", "epoch: 0, update in batch 156000/???, loss: 6.829309940338135\n", "epoch: 0, update in batch 157000/???, loss: 7.045801639556885\n", "epoch: 0, update in batch 158000/???, loss: 6.4073967933654785\n", "epoch: 0, update in batch 159000/???, loss: 6.494145393371582\n", "epoch: 0, update in batch 160000/???, loss: 6.682474613189697\n", "epoch: 0, update in batch 161000/???, loss: 5.125617980957031\n", "epoch: 0, update in batch 162000/???, loss: 5.915367126464844\n", "epoch: 0, update in batch 163000/???, loss: 6.4779157638549805\n", "epoch: 0, update in batch 164000/???, loss: 5.547584533691406\n", "epoch: 0, update in batch 165000/???, loss: 6.134579181671143\n", "epoch: 0, update in batch 166000/???, loss: 5.300144672393799\n", "epoch: 0, update in batch 167000/???, loss: 6.53488826751709\n", "epoch: 0, update in batch 168000/???, loss: 6.711917877197266\n", "epoch: 0, update in batch 169000/???, loss: 7.0150322914123535\n", "epoch: 0, update in batch 170000/???, loss: 5.681846618652344\n", "epoch: 0, update in 
batch 171000/???, loss: 6.583130836486816\n", "epoch: 0, update in batch 172000/???, loss: 6.411820411682129\n", "epoch: 0, update in batch 173000/???, loss: 5.725490093231201\n", "epoch: 0, update in batch 174000/???, loss: 6.651374816894531\n", "epoch: 0, update in batch 175000/???, loss: 5.800152778625488\n", "epoch: 0, update in batch 176000/???, loss: 6.862998962402344\n", "epoch: 0, update in batch 177000/???, loss: 6.668658256530762\n", "epoch: 0, update in batch 178000/???, loss: 6.519270896911621\n", "epoch: 0, update in batch 179000/???, loss: 6.716788291931152\n", "epoch: 0, update in batch 180000/???, loss: 6.675846099853516\n", "epoch: 0, update in batch 181000/???, loss: 6.598060607910156\n", "epoch: 0, update in batch 182000/???, loss: 6.638599395751953\n", "epoch: 0, update in batch 183000/???, loss: 5.693145275115967\n", "epoch: 0, update in batch 184000/???, loss: 5.175653457641602\n", "epoch: 0, update in batch 185000/???, loss: 6.659600734710693\n", "epoch: 0, update in batch 186000/???, loss: 5.782421112060547\n", "epoch: 0, update in batch 187000/???, loss: 6.1736297607421875\n", "epoch: 0, update in batch 188000/???, loss: 5.38541316986084\n", "epoch: 0, update in batch 189000/???, loss: 6.238187789916992\n", "epoch: 0, update in batch 190000/???, loss: 6.10030460357666\n", "epoch: 0, update in batch 191000/???, loss: 6.680960655212402\n", "epoch: 0, update in batch 192000/???, loss: 6.600944519042969\n", "epoch: 0, update in batch 193000/???, loss: 6.171700477600098\n", "epoch: 0, update in batch 194000/???, loss: 7.250021934509277\n", "epoch: 0, update in batch 195000/???, loss: 5.968771934509277\n", "epoch: 0, update in batch 196000/???, loss: 7.107605934143066\n", "epoch: 0, update in batch 197000/???, loss: 6.743283748626709\n", "epoch: 0, update in batch 198000/???, loss: 7.130635738372803\n", "epoch: 0, update in batch 199000/???, loss: 6.37470817565918\n", "epoch: 0, update in batch 200000/???, loss: 6.050590515136719\n", "epoch: 0, 
update in batch 201000/???, loss: 5.468177318572998\n", "epoch: 0, update in batch 202000/???, loss: 6.343471527099609\n", "epoch: 0, update in batch 203000/???, loss: 6.890538692474365\n", "epoch: 0, update in batch 204000/???, loss: 7.018721580505371\n", "epoch: 0, update in batch 205000/???, loss: 6.131939888000488\n", "epoch: 0, update in batch 206000/???, loss: 6.219918251037598\n", "epoch: 0, update in batch 207000/???, loss: 5.858460426330566\n", "epoch: 0, update in batch 208000/???, loss: 6.33021354675293\n", "epoch: 0, update in batch 209000/???, loss: 6.249329566955566\n", "epoch: 0, update in batch 210000/???, loss: 6.263474941253662\n", "epoch: 0, update in batch 211000/???, loss: 6.731234550476074\n", "epoch: 0, update in batch 212000/???, loss: 5.978096961975098\n", "epoch: 0, update in batch 213000/???, loss: 5.148629188537598\n", "epoch: 0, update in batch 214000/???, loss: 6.79285192489624\n", "epoch: 0, update in batch 215000/???, loss: 5.943106651306152\n", "epoch: 0, update in batch 216000/???, loss: 5.749272346496582\n", "epoch: 0, update in batch 217000/???, loss: 6.991009712219238\n", "epoch: 0, update in batch 218000/???, loss: 6.21205997467041\n", "epoch: 0, update in batch 219000/???, loss: 7.519427299499512\n", "epoch: 0, update in batch 220000/???, loss: 5.699267387390137\n", "epoch: 0, update in batch 221000/???, loss: 6.05304479598999\n", "epoch: 0, update in batch 222000/???, loss: 6.422593116760254\n", "epoch: 0, update in batch 223000/???, loss: 6.179877281188965\n", "epoch: 0, update in batch 224000/???, loss: 4.841546058654785\n", "epoch: 0, update in batch 225000/???, loss: 6.666176795959473\n", "epoch: 0, update in batch 226000/???, loss: 5.994054794311523\n", "epoch: 0, update in batch 227000/???, loss: 6.792928218841553\n", "epoch: 0, update in batch 228000/???, loss: 6.9571661949157715\n", "epoch: 0, update in batch 229000/???, loss: 6.198942184448242\n", "epoch: 0, update in batch 230000/???, loss: 5.944539546966553\n", 
"epoch: 0, update in batch 231000/???, loss: 6.188899040222168\n", "epoch: 0, update in batch 232000/???, loss: 5.826596260070801\n", "epoch: 0, update in batch 233000/???, loss: 5.728386878967285\n", "epoch: 0, update in batch 234000/???, loss: 7.6024885177612305\n", "epoch: 0, update in batch 235000/???, loss: 6.728615760803223\n", "epoch: 0, update in batch 236000/???, loss: 6.2461137771606445\n", "epoch: 0, update in batch 237000/???, loss: 6.3110551834106445\n", "epoch: 0, update in batch 238000/???, loss: 6.12617826461792\n", "epoch: 0, update in batch 239000/???, loss: 6.6068243980407715\n", "epoch: 0, update in batch 240000/???, loss: 7.015429496765137\n", "epoch: 0, update in batch 241000/???, loss: 8.444561004638672\n", "epoch: 0, update in batch 242000/???, loss: 7.289303779602051\n", "epoch: 0, update in batch 243000/???, loss: 6.260491371154785\n", "epoch: 0, update in batch 244000/???, loss: 7.60237979888916\n", "epoch: 0, update in batch 245000/???, loss: 6.295613765716553\n", "epoch: 0, update in batch 246000/???, loss: 5.929107666015625\n", "epoch: 0, update in batch 247000/???, loss: 5.835566997528076\n", "epoch: 0, update in batch 248000/???, loss: 5.837784290313721\n", "epoch: 0, update in batch 249000/???, loss: 5.972233772277832\n", "epoch: 0, update in batch 250000/???, loss: 6.0488996505737305\n", "epoch: 0, update in batch 251000/???, loss: 5.712280750274658\n", "epoch: 0, update in batch 252000/???, loss: 5.9513702392578125\n", "epoch: 0, update in batch 253000/???, loss: 5.636294364929199\n", "epoch: 0, update in batch 254000/???, loss: 5.91803503036499\n", "epoch: 0, update in batch 255000/???, loss: 7.285937309265137\n", "epoch: 0, update in batch 256000/???, loss: 6.4795637130737305\n", "epoch: 0, update in batch 257000/???, loss: 6.0709991455078125\n", "epoch: 0, update in batch 258000/???, loss: 5.8723649978637695\n", "epoch: 0, update in batch 259000/???, loss: 5.174002647399902\n", "epoch: 0, update in batch 260000/???, loss: 
6.504033088684082\n", "epoch: 0, update in batch 261000/???, loss: 7.088961601257324\n", "epoch: 0, update in batch 262000/???, loss: 6.2242960929870605\n", "epoch: 0, update in batch 263000/???, loss: 5.970286846160889\n", "epoch: 0, update in batch 264000/???, loss: 5.961676597595215\n", "epoch: 0, update in batch 265000/???, loss: 6.170080661773682\n", "epoch: 0, update in batch 266000/???, loss: 5.477972507476807\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch: 0, update in batch 267000/???, loss: 6.188825607299805\n", "epoch: 0, update in batch 268000/???, loss: 6.518698215484619\n", "epoch: 0, update in batch 269000/???, loss: 5.663434028625488\n", "epoch: 0, update in batch 270000/???, loss: 5.978742599487305\n", "epoch: 0, update in batch 271000/???, loss: 6.217379093170166\n", "epoch: 0, update in batch 272000/???, loss: 5.426600933074951\n", "epoch: 0, update in batch 273000/???, loss: 6.7220964431762695\n", "epoch: 0, update in batch 274000/???, loss: 4.276306629180908\n", "epoch: 0, update in batch 275000/???, loss: 5.420112609863281\n", "epoch: 0, update in batch 276000/???, loss: 5.934456825256348\n", "epoch: 0, update in batch 277000/???, loss: 7.186459541320801\n", "epoch: 0, update in batch 278000/???, loss: 6.126835823059082\n", "epoch: 0, update in batch 279000/???, loss: 5.727339267730713\n", "epoch: 0, update in batch 280000/???, loss: 5.725864410400391\n", "epoch: 0, update in batch 281000/???, loss: 5.47005033493042\n", "epoch: 0, update in batch 282000/???, loss: 6.217499732971191\n", "epoch: 0, update in batch 283000/???, loss: 6.022196292877197\n", "epoch: 0, update in batch 284000/???, loss: 5.932379722595215\n", "epoch: 0, update in batch 285000/???, loss: 6.321987628936768\n", "epoch: 0, update in batch 286000/???, loss: 7.480570316314697\n", "epoch: 0, update in batch 287000/???, loss: 5.169373512268066\n", "epoch: 0, update in batch 288000/???, loss: 6.301320552825928\n", "epoch: 0, update in batch 289000/???, 
loss: 6.4635009765625\n", "epoch: 0, update in batch 290000/???, loss: 6.8701887130737305\n", "epoch: 0, update in batch 291000/???, loss: 6.036175727844238\n", "epoch: 0, update in batch 292000/???, loss: 6.705732822418213\n", "epoch: 0, update in batch 293000/???, loss: 6.99608850479126\n", "epoch: 0, update in batch 294000/???, loss: 6.50225305557251\n", "epoch: 0, update in batch 295000/???, loss: 6.03929328918457\n", "epoch: 0, update in batch 296000/???, loss: 5.498082160949707\n", "epoch: 0, update in batch 297000/???, loss: 6.04677677154541\n", "epoch: 0, update in batch 298000/???, loss: 6.482898712158203\n", "epoch: 0, update in batch 299000/???, loss: 7.235076904296875\n", "epoch: 0, update in batch 300000/???, loss: 6.019383907318115\n", "epoch: 0, update in batch 301000/???, loss: 7.082001686096191\n", "epoch: 0, update in batch 302000/???, loss: 6.447659492492676\n", "epoch: 0, update in batch 303000/???, loss: 5.94022798538208\n", "epoch: 0, update in batch 304000/???, loss: 6.459266662597656\n", "epoch: 0, update in batch 305000/???, loss: 6.281588077545166\n", "epoch: 0, update in batch 306000/???, loss: 7.022011756896973\n", "epoch: 0, update in batch 307000/???, loss: 6.1802263259887695\n", "epoch: 0, update in batch 308000/???, loss: 4.189492225646973\n", "epoch: 0, update in batch 309000/???, loss: 6.7040696144104\n", "epoch: 0, update in batch 310000/???, loss: 6.589522361755371\n", "epoch: 0, update in batch 311000/???, loss: 6.243889808654785\n", "epoch: 0, update in batch 312000/???, loss: 5.490180015563965\n", "epoch: 0, update in batch 313000/???, loss: 5.9699201583862305\n", "epoch: 0, update in batch 314000/???, loss: 7.321981906890869\n", "epoch: 0, update in batch 315000/???, loss: 4.731215953826904\n", "epoch: 0, update in batch 316000/???, loss: 5.845946788787842\n", "epoch: 0, update in batch 317000/???, loss: 5.917788505554199\n", "epoch: 0, update in batch 318000/???, loss: 6.420014381408691\n", "epoch: 0, update in batch 
319000/???, loss: 6.550830841064453\n", "epoch: 0, update in batch 320000/???, loss: 6.751360893249512\n", "epoch: 0, update in batch 321000/???, loss: 5.025134086608887\n", "epoch: 0, update in batch 322000/???, loss: 6.368621826171875\n", "epoch: 0, update in batch 323000/???, loss: 6.2042083740234375\n", "epoch: 0, update in batch 324000/???, loss: 6.173147678375244\n", "epoch: 0, update in batch 325000/???, loss: 5.865999221801758\n", "epoch: 0, update in batch 326000/???, loss: 6.844902992248535\n", "epoch: 0, update in batch 327000/???, loss: 6.080742359161377\n", "epoch: 0, update in batch 328000/???, loss: 5.41788387298584\n", "epoch: 0, update in batch 329000/???, loss: 5.831374645233154\n", "epoch: 0, update in batch 330000/???, loss: 6.4492506980896\n", "epoch: 0, update in batch 331000/???, loss: 6.220627784729004\n", "epoch: 0, update in batch 332000/???, loss: 5.880006313323975\n", "epoch: 0, update in batch 333000/???, loss: 6.806972503662109\n", "epoch: 0, update in batch 334000/???, loss: 7.165728569030762\n", "epoch: 0, update in batch 335000/???, loss: 6.322948932647705\n", "epoch: 0, update in batch 336000/???, loss: 6.206046104431152\n", "epoch: 0, update in batch 337000/???, loss: 6.097958564758301\n", "epoch: 0, update in batch 338000/???, loss: 6.7682952880859375\n", "epoch: 0, update in batch 339000/???, loss: 5.2390642166137695\n", "epoch: 0, update in batch 340000/???, loss: 6.913119316101074\n" ] } ], "source": [ "train(train_dataset_front, model_front, 1, 64)" ] }, { "cell_type": "code", "execution_count": 14, "id": "132d9157", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "epoch: 0, update in batch 0/???, loss: 11.3253755569458\n", "epoch: 0, update in batch 1000/???, loss: 5.709358215332031\n", "epoch: 0, update in batch 2000/???, loss: 7.989391326904297\n", "epoch: 0, update in batch 3000/???, loss: 6.578714847564697\n", "epoch: 0, update in batch 4000/???, loss: 7.051873207092285\n", "epoch: 0, 
update in batch 5000/???, loss: 6.85653018951416\n", "epoch: 0, update in batch 6000/???, loss: 6.812790870666504\n", "epoch: 0, update in batch 7000/???, loss: 6.9604010581970215\n", "epoch: 0, update in batch 8000/???, loss: 6.798591613769531\n", "epoch: 0, update in batch 9000/???, loss: 6.415241241455078\n", "epoch: 0, update in batch 10000/???, loss: 6.6636223793029785\n", "epoch: 0, update in batch 11000/???, loss: 6.593747138977051\n", "epoch: 0, update in batch 12000/???, loss: 6.914702415466309\n", "epoch: 0, update in batch 13000/???, loss: 5.542675971984863\n", "epoch: 0, update in batch 14000/???, loss: 6.5461883544921875\n", "epoch: 0, update in batch 15000/???, loss: 7.507067680358887\n", "epoch: 0, update in batch 16000/???, loss: 5.425755500793457\n", "epoch: 0, update in batch 17000/???, loss: 6.285205841064453\n", "epoch: 0, update in batch 18000/???, loss: 4.223124027252197\n", "epoch: 0, update in batch 19000/???, loss: 6.530254364013672\n", "epoch: 0, update in batch 20000/???, loss: 6.091847896575928\n", "epoch: 0, update in batch 21000/???, loss: 7.088344573974609\n", "epoch: 0, update in batch 22000/???, loss: 5.925537109375\n", "epoch: 0, update in batch 23000/???, loss: 6.3628082275390625\n", "epoch: 0, update in batch 24000/???, loss: 6.604581356048584\n", "epoch: 0, update in batch 25000/???, loss: 6.2706499099731445\n", "epoch: 0, update in batch 26000/???, loss: 6.114742755889893\n", "epoch: 0, update in batch 27000/???, loss: 5.686783790588379\n", "epoch: 0, update in batch 28000/???, loss: 5.5114521980285645\n", "epoch: 0, update in batch 29000/???, loss: 6.999403953552246\n", "epoch: 0, update in batch 30000/???, loss: 5.834499359130859\n", "epoch: 0, update in batch 31000/???, loss: 5.873156547546387\n", "epoch: 0, update in batch 32000/???, loss: 6.246962547302246\n", "epoch: 0, update in batch 33000/???, loss: 6.742733955383301\n", "epoch: 0, update in batch 34000/???, loss: 6.832881927490234\n", "epoch: 0, update in batch 
35000/???, loss: 6.625868320465088\n", "epoch: 0, update in batch 36000/???, loss: 6.653105735778809\n", "epoch: 0, update in batch 37000/???, loss: 6.104651927947998\n", "epoch: 0, update in batch 38000/???, loss: 6.301898002624512\n", "epoch: 0, update in batch 39000/???, loss: 7.377936363220215\n", "epoch: 0, update in batch 40000/???, loss: 6.26895809173584\n", "epoch: 0, update in batch 41000/???, loss: 6.602926731109619\n", "epoch: 0, update in batch 42000/???, loss: 6.419803619384766\n", "epoch: 0, update in batch 43000/???, loss: 7.187136650085449\n", "epoch: 0, update in batch 44000/???, loss: 6.382015705108643\n", "epoch: 0, update in batch 45000/???, loss: 6.044090747833252\n", "epoch: 0, update in batch 46000/???, loss: 5.707688808441162\n", "epoch: 0, update in batch 47000/???, loss: 7.007757663726807\n", "epoch: 0, update in batch 48000/???, loss: 5.365390300750732\n", "epoch: 0, update in batch 49000/???, loss: 5.510242938995361\n", "epoch: 0, update in batch 50000/???, loss: 5.955991268157959\n", "epoch: 0, update in batch 51000/???, loss: 6.2313032150268555\n", "epoch: 0, update in batch 52000/???, loss: 8.19306468963623\n", "epoch: 0, update in batch 53000/???, loss: 6.345375061035156\n", "epoch: 0, update in batch 54000/???, loss: 7.044759273529053\n", "epoch: 0, update in batch 55000/???, loss: 6.2544779777526855\n", "epoch: 0, update in batch 56000/???, loss: 6.315605163574219\n", "epoch: 0, update in batch 57000/???, loss: 5.632706642150879\n", "epoch: 0, update in batch 58000/???, loss: 6.0897536277771\n", "epoch: 0, update in batch 59000/???, loss: 5.562952518463135\n", "epoch: 0, update in batch 60000/???, loss: 5.519134044647217\n", "epoch: 0, update in batch 61000/???, loss: 6.394771099090576\n", "epoch: 0, update in batch 62000/???, loss: 6.147246360778809\n", "epoch: 0, update in batch 63000/???, loss: 5.798914909362793\n", "epoch: 0, update in batch 64000/???, loss: 6.026059627532959\n", "epoch: 0, update in batch 65000/???, loss: 
6.4533233642578125\n", "epoch: 0, update in batch 66000/???, loss: 6.383795738220215\n", "epoch: 0, update in batch 67000/???, loss: 6.466322898864746\n", "epoch: 0, update in batch 68000/???, loss: 6.8227715492248535\n", "epoch: 0, update in batch 69000/???, loss: 6.283398151397705\n", "epoch: 0, update in batch 70000/???, loss: 4.547608375549316\n", "epoch: 0, update in batch 71000/???, loss: 6.008975028991699\n", "epoch: 0, update in batch 72000/???, loss: 5.674825191497803\n", "epoch: 0, update in batch 73000/???, loss: 5.134644508361816\n", "epoch: 0, update in batch 74000/???, loss: 6.906868934631348\n", "epoch: 0, update in batch 75000/???, loss: 6.672898292541504\n", "epoch: 0, update in batch 76000/???, loss: 5.813290596008301\n", "epoch: 0, update in batch 77000/???, loss: 6.296219825744629\n", "epoch: 0, update in batch 78000/???, loss: 6.531443119049072\n", "epoch: 0, update in batch 79000/???, loss: 6.437461853027344\n", "epoch: 0, update in batch 80000/???, loss: 6.2280778884887695\n", "epoch: 0, update in batch 81000/???, loss: 6.805241584777832\n", "epoch: 0, update in batch 82000/???, loss: 7.044824123382568\n", "epoch: 0, update in batch 83000/???, loss: 7.348274230957031\n", "epoch: 0, update in batch 84000/???, loss: 5.826806545257568\n", "epoch: 0, update in batch 85000/???, loss: 5.474950313568115\n", "epoch: 0, update in batch 86000/???, loss: 6.497323036193848\n", "epoch: 0, update in batch 87000/???, loss: 5.88934850692749\n", "epoch: 0, update in batch 88000/???, loss: 5.371798038482666\n", "epoch: 0, update in batch 89000/???, loss: 6.093968391418457\n", "epoch: 0, update in batch 90000/???, loss: 6.115981578826904\n", "epoch: 0, update in batch 91000/???, loss: 6.504927158355713\n", "epoch: 0, update in batch 92000/???, loss: 6.239808082580566\n", "epoch: 0, update in batch 93000/???, loss: 5.384994983673096\n", "epoch: 0, update in batch 94000/???, loss: 6.422779083251953\n", "epoch: 0, update in batch 95000/???, loss: 
7.163965702056885\n", "epoch: 0, update in batch 96000/???, loss: 6.44806432723999\n", "epoch: 0, update in batch 97000/???, loss: 6.153664588928223\n", "epoch: 0, update in batch 98000/???, loss: 5.9013776779174805\n", "epoch: 0, update in batch 99000/???, loss: 6.198166847229004\n", "epoch: 0, update in batch 100000/???, loss: 5.752341270446777\n", "epoch: 0, update in batch 101000/???, loss: 6.455883979797363\n", "epoch: 0, update in batch 102000/???, loss: 5.270313262939453\n", "epoch: 0, update in batch 103000/???, loss: 6.475237846374512\n", "epoch: 0, update in batch 104000/???, loss: 6.2444844245910645\n", "epoch: 0, update in batch 105000/???, loss: 6.1563720703125\n", "epoch: 0, update in batch 106000/???, loss: 6.12777853012085\n", "epoch: 0, update in batch 107000/???, loss: 6.449145317077637\n", "epoch: 0, update in batch 108000/???, loss: 6.515239715576172\n", "epoch: 0, update in batch 109000/???, loss: 5.6317644119262695\n", "epoch: 0, update in batch 110000/???, loss: 6.09606409072876\n", "epoch: 0, update in batch 111000/???, loss: 7.069797515869141\n", "epoch: 0, update in batch 112000/???, loss: 7.456076145172119\n", "epoch: 0, update in batch 113000/???, loss: 6.668386936187744\n", "epoch: 0, update in batch 114000/???, loss: 7.705430507659912\n", "epoch: 0, update in batch 115000/???, loss: 6.983656883239746\n", "epoch: 0, update in batch 116000/???, loss: 6.320417404174805\n", "epoch: 0, update in batch 117000/???, loss: 7.184473991394043\n", "epoch: 0, update in batch 118000/???, loss: 6.603268623352051\n", "epoch: 0, update in batch 119000/???, loss: 6.670085906982422\n", "epoch: 0, update in batch 120000/???, loss: 6.748586177825928\n", "epoch: 0, update in batch 121000/???, loss: 6.353959560394287\n", "epoch: 0, update in batch 122000/???, loss: 5.138751029968262\n", "epoch: 0, update in batch 123000/???, loss: 6.507109642028809\n", "epoch: 0, update in batch 124000/???, loss: 6.360246181488037\n", "epoch: 0, update in batch 125000/???, 
loss: 7.164086818695068\n", "epoch: 0, update in batch 126000/???, loss: 5.610747337341309\n", "epoch: 0, update in batch 127000/???, loss: 5.066179275512695\n", "epoch: 0, update in batch 128000/???, loss: 5.688697814941406\n", "epoch: 0, update in batch 129000/???, loss: 6.960330963134766\n", "epoch: 0, update in batch 130000/???, loss: 5.818534851074219\n", "epoch: 0, update in batch 131000/???, loss: 6.186715602874756\n", "epoch: 0, update in batch 132000/???, loss: 5.825492858886719\n", "epoch: 0, update in batch 133000/???, loss: 5.576340675354004\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch: 0, update in batch 134000/???, loss: 5.503821849822998\n", "epoch: 0, update in batch 135000/???, loss: 6.428965091705322\n", "epoch: 0, update in batch 136000/???, loss: 5.102448463439941\n", "epoch: 0, update in batch 137000/???, loss: 6.239314556121826\n", "epoch: 0, update in batch 138000/???, loss: 6.028595447540283\n", "epoch: 0, update in batch 139000/???, loss: 6.407244682312012\n", "epoch: 0, update in batch 140000/???, loss: 5.597055912017822\n", "epoch: 0, update in batch 141000/???, loss: 5.823704719543457\n", "epoch: 0, update in batch 142000/???, loss: 6.665535926818848\n", "epoch: 0, update in batch 143000/???, loss: 5.5736894607543945\n", "epoch: 0, update in batch 144000/???, loss: 6.723180294036865\n", "epoch: 0, update in batch 145000/???, loss: 6.378345489501953\n", "epoch: 0, update in batch 146000/???, loss: 5.6936845779418945\n", "epoch: 0, update in batch 147000/???, loss: 5.761658668518066\n", "epoch: 0, update in batch 148000/???, loss: 5.580254077911377\n", "epoch: 0, update in batch 149000/???, loss: 5.733176231384277\n", "epoch: 0, update in batch 150000/???, loss: 6.901691436767578\n", "epoch: 0, update in batch 151000/???, loss: 6.5111589431762695\n", "epoch: 0, update in batch 152000/???, loss: 6.184727668762207\n", "epoch: 0, update in batch 153000/???, loss: 7.407107353210449\n", "epoch: 0, update in batch 
154000/???, loss: 6.499199867248535\n", "epoch: 0, update in batch 155000/???, loss: 5.143393516540527\n", "epoch: 0, update in batch 156000/???, loss: 7.60940408706665\n", "epoch: 0, update in batch 157000/???, loss: 6.766045570373535\n", "epoch: 0, update in batch 158000/???, loss: 5.268759727478027\n", "epoch: 0, update in batch 159000/???, loss: 7.558129787445068\n", "epoch: 0, update in batch 160000/???, loss: 8.016000747680664\n", "epoch: 0, update in batch 161000/???, loss: 5.959166526794434\n", "epoch: 0, update in batch 162000/???, loss: 5.499085426330566\n", "epoch: 0, update in batch 163000/???, loss: 6.581662654876709\n", "epoch: 0, update in batch 164000/???, loss: 6.681334495544434\n", "epoch: 0, update in batch 165000/???, loss: 7.817207336425781\n", "epoch: 0, update in batch 166000/???, loss: 6.524381160736084\n", "epoch: 0, update in batch 167000/???, loss: 5.903864860534668\n", "epoch: 0, update in batch 168000/???, loss: 5.6087260246276855\n", "epoch: 0, update in batch 169000/???, loss: 5.742824554443359\n", "epoch: 0, update in batch 170000/???, loss: 6.129671096801758\n", "epoch: 0, update in batch 171000/???, loss: 5.879034519195557\n", "epoch: 0, update in batch 172000/???, loss: 6.322129249572754\n", "epoch: 0, update in batch 173000/???, loss: 6.805352210998535\n", "epoch: 0, update in batch 174000/???, loss: 7.162431240081787\n", "epoch: 0, update in batch 175000/???, loss: 6.123959541320801\n", "epoch: 0, update in batch 176000/???, loss: 7.544029235839844\n", "epoch: 0, update in batch 177000/???, loss: 5.4254021644592285\n", "epoch: 0, update in batch 178000/???, loss: 5.784268379211426\n", "epoch: 0, update in batch 179000/???, loss: 5.8633856773376465\n", "epoch: 0, update in batch 180000/???, loss: 6.556314945220947\n", "epoch: 0, update in batch 181000/???, loss: 5.215446472167969\n", "epoch: 0, update in batch 182000/???, loss: 6.079234600067139\n", "epoch: 0, update in batch 183000/???, loss: 7.234827995300293\n", "epoch: 0, 
update in batch 184000/???, loss: 5.249889373779297\n", "epoch: 0, update in batch 185000/???, loss: 5.083311080932617\n", "epoch: 0, update in batch 186000/???, loss: 6.061867713928223\n", "epoch: 0, update in batch 187000/???, loss: 6.060431480407715\n", "epoch: 0, update in batch 188000/???, loss: 5.572680950164795\n", "epoch: 0, update in batch 189000/???, loss: 5.991988182067871\n", "epoch: 0, update in batch 190000/???, loss: 6.521245002746582\n", "epoch: 0, update in batch 191000/???, loss: 5.128615379333496\n", "epoch: 0, update in batch 192000/???, loss: 5.616750717163086\n", "epoch: 0, update in batch 193000/???, loss: 6.1465044021606445\n", "epoch: 0, update in batch 194000/???, loss: 5.93985652923584\n", "epoch: 0, update in batch 195000/???, loss: 6.268892765045166\n", "epoch: 0, update in batch 196000/???, loss: 5.928576469421387\n", "epoch: 0, update in batch 197000/???, loss: 5.257290363311768\n", "epoch: 0, update in batch 198000/???, loss: 6.6432952880859375\n", "epoch: 0, update in batch 199000/???, loss: 6.898074150085449\n", "epoch: 0, update in batch 200000/???, loss: 7.042447566986084\n", "epoch: 0, update in batch 201000/???, loss: 7.104043483734131\n", "epoch: 0, update in batch 202000/???, loss: 6.238812446594238\n", "epoch: 0, update in batch 203000/???, loss: 6.773525238037109\n", "epoch: 0, update in batch 204000/???, loss: 5.054592132568359\n", "epoch: 0, update in batch 205000/???, loss: 6.854428768157959\n", "epoch: 0, update in batch 206000/???, loss: 5.9983601570129395\n", "epoch: 0, update in batch 207000/???, loss: 5.236695766448975\n", "epoch: 0, update in batch 208000/???, loss: 6.086891174316406\n", "epoch: 0, update in batch 209000/???, loss: 6.134495258331299\n", "epoch: 0, update in batch 210000/???, loss: 6.52248477935791\n", "epoch: 0, update in batch 211000/???, loss: 6.028376579284668\n", "epoch: 0, update in batch 212000/???, loss: 6.140281677246094\n", "epoch: 0, update in batch 213000/???, loss: 6.066422462463379\n", 
"epoch: 0, update in batch 214000/???, loss: 6.868189334869385\n", "epoch: 0, update in batch 215000/???, loss: 6.641358852386475\n", "epoch: 0, update in batch 216000/???, loss: 6.818638801574707\n", "epoch: 0, update in batch 217000/???, loss: 6.40252685546875\n", "epoch: 0, update in batch 218000/???, loss: 5.561617851257324\n", "epoch: 0, update in batch 219000/???, loss: 6.434267997741699\n", "epoch: 0, update in batch 220000/???, loss: 6.33272123336792\n", "epoch: 0, update in batch 221000/???, loss: 5.75616979598999\n", "epoch: 0, update in batch 222000/???, loss: 6.477814674377441\n", "epoch: 0, update in batch 223000/???, loss: 5.259497165679932\n", "epoch: 0, update in batch 224000/???, loss: 5.8639655113220215\n", "epoch: 0, update in batch 225000/???, loss: 6.469706058502197\n", "epoch: 0, update in batch 226000/???, loss: 5.707249164581299\n", "epoch: 0, update in batch 227000/???, loss: 6.394181251525879\n", "epoch: 0, update in batch 228000/???, loss: 5.048886299133301\n", "epoch: 0, update in batch 229000/???, loss: 5.842928409576416\n", "epoch: 0, update in batch 230000/???, loss: 5.627688407897949\n", "epoch: 0, update in batch 231000/???, loss: 7.950299263000488\n", "epoch: 0, update in batch 232000/???, loss: 6.771368503570557\n", "epoch: 0, update in batch 233000/???, loss: 5.787235260009766\n", "epoch: 0, update in batch 234000/???, loss: 5.6070780754089355\n", "epoch: 0, update in batch 235000/???, loss: 6.060035705566406\n", "epoch: 0, update in batch 236000/???, loss: 6.894829750061035\n", "epoch: 0, update in batch 237000/???, loss: 5.672856330871582\n", "epoch: 0, update in batch 238000/???, loss: 5.054213523864746\n", "epoch: 0, update in batch 239000/???, loss: 6.484643459320068\n", "epoch: 0, update in batch 240000/???, loss: 5.800728797912598\n", "epoch: 0, update in batch 241000/???, loss: 5.148013591766357\n", "epoch: 0, update in batch 242000/???, loss: 5.529184818267822\n", "epoch: 0, update in batch 243000/???, loss: 
5.959448337554932\n", "epoch: 0, update in batch 244000/???, loss: 6.762448787689209\n", "epoch: 0, update in batch 245000/???, loss: 4.907589912414551\n", "epoch: 0, update in batch 246000/???, loss: 6.275182723999023\n", "epoch: 0, update in batch 247000/???, loss: 5.7234015464782715\n", "epoch: 0, update in batch 248000/???, loss: 6.119207859039307\n", "epoch: 0, update in batch 249000/???, loss: 5.297057151794434\n", "epoch: 0, update in batch 250000/???, loss: 5.924614906311035\n", "epoch: 0, update in batch 251000/???, loss: 6.651083469390869\n", "epoch: 0, update in batch 252000/???, loss: 5.7164201736450195\n", "epoch: 0, update in batch 253000/???, loss: 6.105191230773926\n", "epoch: 0, update in batch 254000/???, loss: 5.791018486022949\n", "epoch: 0, update in batch 255000/???, loss: 6.659502983093262\n", "epoch: 0, update in batch 256000/???, loss: 5.613073348999023\n", "epoch: 0, update in batch 257000/???, loss: 7.501049041748047\n", "epoch: 0, update in batch 258000/???, loss: 6.043797492980957\n", "epoch: 0, update in batch 259000/???, loss: 7.3587327003479\n", "epoch: 0, update in batch 260000/???, loss: 6.276612281799316\n", "epoch: 0, update in batch 261000/???, loss: 6.445192813873291\n", "epoch: 0, update in batch 262000/???, loss: 5.0266547203063965\n", "epoch: 0, update in batch 263000/???, loss: 6.404935359954834\n", "epoch: 0, update in batch 264000/???, loss: 6.5042290687561035\n", "epoch: 0, update in batch 265000/???, loss: 6.880773067474365\n", "epoch: 0, update in batch 266000/???, loss: 6.3690643310546875\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "epoch: 0, update in batch 267000/???, loss: 6.055562973022461\n", "epoch: 0, update in batch 268000/???, loss: 5.796906471252441\n", "epoch: 0, update in batch 269000/???, loss: 5.654962539672852\n", "epoch: 0, update in batch 270000/???, loss: 6.574362277984619\n", "epoch: 0, update in batch 271000/???, loss: 6.256768226623535\n", "epoch: 0, update in batch 272000/???, 
loss: 6.8345208168029785\n", "epoch: 0, update in batch 273000/???, loss: 6.066469669342041\n", "epoch: 0, update in batch 274000/???, loss: 6.625809669494629\n", "epoch: 0, update in batch 275000/???, loss: 4.762896537780762\n", "epoch: 0, update in batch 276000/???, loss: 6.019833564758301\n", "epoch: 0, update in batch 277000/???, loss: 6.227939605712891\n", "epoch: 0, update in batch 278000/???, loss: 7.046879768371582\n", "epoch: 0, update in batch 279000/???, loss: 6.068551540374756\n", "epoch: 0, update in batch 280000/???, loss: 6.454771995544434\n", "epoch: 0, update in batch 281000/???, loss: 3.9379985332489014\n", "epoch: 0, update in batch 282000/???, loss: 5.615240097045898\n", "epoch: 0, update in batch 283000/???, loss: 5.7963151931762695\n", "epoch: 0, update in batch 284000/???, loss: 6.064437389373779\n", "epoch: 0, update in batch 285000/???, loss: 6.668734073638916\n", "epoch: 0, update in batch 286000/???, loss: 6.776829719543457\n", "epoch: 0, update in batch 287000/???, loss: 6.170516014099121\n", "epoch: 0, update in batch 288000/???, loss: 4.840399742126465\n", "epoch: 0, update in batch 289000/???, loss: 6.333052635192871\n", "epoch: 0, update in batch 290000/???, loss: 5.595047950744629\n", "epoch: 0, update in batch 291000/???, loss: 6.594934940338135\n", "epoch: 0, update in batch 292000/???, loss: 5.950274467468262\n", "epoch: 0, update in batch 293000/???, loss: 6.123660087585449\n", "epoch: 0, update in batch 294000/???, loss: 5.904355049133301\n", "epoch: 0, update in batch 295000/???, loss: 5.8828630447387695\n", "epoch: 0, update in batch 296000/???, loss: 5.604973316192627\n", "epoch: 0, update in batch 297000/???, loss: 4.842469692230225\n", "epoch: 0, update in batch 298000/???, loss: 5.862446308135986\n", "epoch: 0, update in batch 299000/???, loss: 6.90258264541626\n", "epoch: 0, update in batch 300000/???, loss: 5.941957950592041\n", "epoch: 0, update in batch 301000/???, loss: 5.697750568389893\n", "epoch: 0, update in 
batch 302000/???, loss: 5.973014831542969\n", "epoch: 0, update in batch 303000/???, loss: 5.46022367477417\n", "epoch: 0, update in batch 304000/???, loss: 6.5218095779418945\n", "epoch: 0, update in batch 305000/???, loss: 6.392545700073242\n", "epoch: 0, update in batch 306000/???, loss: 7.080249786376953\n", "epoch: 0, update in batch 307000/???, loss: 6.355096817016602\n", "epoch: 0, update in batch 308000/???, loss: 5.625491619110107\n", "epoch: 0, update in batch 309000/???, loss: 6.805799961090088\n", "epoch: 0, update in batch 310000/???, loss: 6.426385402679443\n", "epoch: 0, update in batch 311000/???, loss: 5.727842807769775\n", "epoch: 0, update in batch 312000/???, loss: 6.9111199378967285\n", "epoch: 0, update in batch 313000/???, loss: 6.40056848526001\n", "epoch: 0, update in batch 314000/???, loss: 6.145076751708984\n", "epoch: 0, update in batch 315000/???, loss: 6.097104072570801\n", "epoch: 0, update in batch 316000/???, loss: 5.39146089553833\n", "epoch: 0, update in batch 317000/???, loss: 6.125569820404053\n", "epoch: 0, update in batch 318000/???, loss: 6.533677577972412\n", "epoch: 0, update in batch 319000/???, loss: 5.944211483001709\n", "epoch: 0, update in batch 320000/???, loss: 6.542410850524902\n", "epoch: 0, update in batch 321000/???, loss: 5.699315071105957\n", "epoch: 0, update in batch 322000/???, loss: 6.251957893371582\n", "epoch: 0, update in batch 323000/???, loss: 5.346350193023682\n", "epoch: 0, update in batch 324000/???, loss: 5.603858470916748\n", "epoch: 0, update in batch 325000/???, loss: 5.740134239196777\n", "epoch: 0, update in batch 326000/???, loss: 5.575300693511963\n", "epoch: 0, update in batch 327000/???, loss: 6.996762752532959\n", "epoch: 0, update in batch 328000/???, loss: 6.28995418548584\n", "epoch: 0, update in batch 329000/???, loss: 4.519123077392578\n", "epoch: 0, update in batch 330000/???, loss: 5.9068121910095215\n", "epoch: 0, update in batch 331000/???, loss: 6.61830997467041\n", "epoch: 0, 
update in batch 332000/???, loss: 6.063097953796387\n", "epoch: 0, update in batch 333000/???, loss: 6.419328212738037\n", "epoch: 0, update in batch 334000/???, loss: 5.927584648132324\n", "epoch: 0, update in batch 335000/???, loss: 5.527887344360352\n", "epoch: 0, update in batch 336000/???, loss: 6.114096641540527\n", "epoch: 0, update in batch 337000/???, loss: 5.9415082931518555\n", "epoch: 0, update in batch 338000/???, loss: 5.288441181182861\n", "epoch: 0, update in batch 339000/???, loss: 6.611715793609619\n", "epoch: 0, update in batch 340000/???, loss: 6.770573616027832\n" ] } ], "source": [ "# Train the backward model; presumably 1 epoch with batch size 64 -\n", "# confirm against the train() definition in an earlier cell.\n", "train(train_dataset_back, model_back, 1, 64)" ] }, { "cell_type": "code", "execution_count": 30, "id": "36a1b802", "metadata": {}, "outputs": [], "source": [ "def predict_probs(left_tokens, right_tokens):\n", " # Score the gap between the two contexts: run the forward model on\n", " # left_tokens and the backward model on right_tokens (already reversed\n", " # by the caller), average the two softmax distributions, and format\n", " # the 30 best candidates as 'word:prob ... :rest'.\n", " # NOTE(review): model_front, model_back and train_dataset_front are\n", " # defined in earlier cells not visible here - presumably the two\n", " # directional LSTM language models; confirm against those cells.\n", " model_front.eval()\n", " model_back.eval()\n", "\n", " # Map tokens to vocab indices; OOV tokens fall back to the '' entry\n", " # appended to the Word2Vec vocab. The membership test uses the global\n", " # key_to_index, which is the same mapping as\n", " # train_dataset_front.key_to_index.\n", " x_left = torch.tensor([[train_dataset_front.key_to_index[w] if w in key_to_index else train_dataset_front.key_to_index[''] for w in left_tokens]]).to(device)\n", " x_right = torch.tensor([[train_dataset_front.key_to_index[w] if w in key_to_index else train_dataset_front.key_to_index[''] for w in right_tokens]]).to(device)\n", " y_pred_left, (state_h_left, state_c_left) = model_front(x_left)\n", " y_pred_right, (state_h_right, state_c_right) = model_back(x_right)\n", "\n", " # Only the prediction at the last position of each context is used.\n", " last_word_logits_left = y_pred_left[0][-1]\n", " last_word_logits_right = y_pred_right[0][-1]\n", " probs_left = torch.nn.functional.softmax(last_word_logits_left, dim=0).detach().cpu().numpy()\n", " probs_right = torch.nn.functional.softmax(last_word_logits_right, dim=0).detach().cpu().numpy()\n", " \n", " # Element-wise average of the two directions' distributions.\n", " probs = [np.mean(k) for k in zip(probs_left, probs_right)]\n", " \n", " # Naive top-30 selection (O(len(probs) * 30)): keep the 30 most\n", " # probable indices; the last index (the '' unknown token) is never\n", " # allowed to replace an entry once the list is full.\n", " top_words = []\n", " for index in range(len(probs)):\n", " if len(top_words) < 30:\n", " top_words.append((probs[index], [index]))\n", " else:\n", " worst_word = None\n", " for word in top_words:\n", " if not worst_word:\n", " worst_word = word\n", " else:\n", " if word[0] < worst_word[0]:\n", " worst_word = word\n", " if worst_word[0] < probs[index] and index != len(probs) - 1:\n", " top_words.remove(worst_word)\n", " top_words.append((probs[index], [index]))\n", " \n", " # Format the output line. NOTE(review): word_index actually holds the\n", " # probability (word[0]), not an index - the emitted 'word:prob' pairs\n", " # are correct, only the name is misleading. The trailing ':p' entry is\n", " # the leftover probability mass for all words outside the top 30.\n", " prediction = ''\n", " sum_prob = 0.0\n", " for word in top_words:\n", " sum_prob += word[0]\n", " word_index = word[0]\n", " word_text = index_to_key[word[1][0]]\n", " prediction += f'{word_text}:{word_index} '\n", " prediction += f':{1 - sum_prob}'\n", " \n", " return prediction" ] }, { "cell_type": "code", "execution_count": 16, "id": "155636b5", "metadata": {}, "outputs": [], "source": [ "# Dev and test inputs use the same layout as the training file: columns\n", "# 6 and 7 hold the left and right context of the gap.\n", "dev_data = pd.read_csv('dev-0/in.tsv.xz', sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)\n", "test_data = pd.read_csv('test-A/in.tsv.xz', sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)" ] }, { "cell_type": "code", "execution_count": 39, "id": "99b1d944", "metadata": {}, "outputs": [], "source": [ "# Write dev-0 predictions. The right context is reversed to match the\n", "# token order the backward model was trained on (TrainDataset with\n", "# reversed=True); rows with fewer than 6 tokens on either side get the\n", "# ':1.0' fallback (all probability mass on the unknown remainder).\n", "with open('dev-0/out.tsv', 'w') as file:\n", " for index, row in dev_data.iterrows():\n", " left_text = clean_text(str(row[6]))\n", " right_text = clean_text(str(row[7]))\n", " left_words = word_tokenize(left_text)\n", " right_words = word_tokenize(right_text)\n", " right_words.reverse()\n", " if len(left_words) < 6 or len(right_words) < 6:\n", " prediction = ':1.0'\n", " else:\n", " prediction = predict_probs(left_words[-5:], right_words[-5:])\n", " file.write(prediction + '\\n')" ] }, { "cell_type": "code", "execution_count": 41, "id": "186c3269", "metadata": {}, "outputs": [], "source": [ "# Same procedure as for dev-0, applied to the test-A split: reverse the\n", "# right context, fall back to ':1.0' when either context is too short.\n", "with open('test-A/out.tsv', 'w') as file:\n", " for index, row in test_data.iterrows():\n", " left_text = clean_text(str(row[6]))\n", " right_text = clean_text(str(row[7]))\n", " left_words = word_tokenize(left_text)\n", " right_words = word_tokenize(right_text)\n", " right_words.reverse()\n", " if len(left_words) < 6 or len(right_words) < 6:\n", " prediction = ':1.0'\n", " else:\n", " prediction = predict_probs(left_words[-5:], right_words[-5:])\n", " file.write(prediction + '\\n')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.2" } }, "nbformat": 4, "nbformat_minor": 5 }