challenging-america-word-ga.../run.ipynb

1107 lines
62 KiB
Plaintext
Raw Permalink Normal View History

2022-05-29 12:19:20 +02:00
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "03de852a",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import regex as re\n",
"import csv\n",
"import torch\n",
"from torch import nn\n",
"from gensim.models import Word2Vec\n",
"from nltk.tokenize import word_tokenize"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "73497953",
"metadata": {},
"outputs": [],
"source": [
"# Free cached GPU memory and pick the compute device used by all later cells.\n",
"torch.cuda.empty_cache()\n",
"device = 'cuda' if torch.cuda.is_available() else 'cpu'"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "4227ef55",
"metadata": {},
"outputs": [],
"source": [
"def clean_text(text):\n",
"    \"\"\"Normalize raw OCR text: lowercase, unwrap lines, expand contractions, strip punctuation.\"\"\"\n",
"    # Undo hard line-wrapping: hyphenated breaks are joined, plain breaks become spaces.\n",
"    text = text.lower().replace('-\\\\\\\\\\\\\\\\n', '').replace('\\\\\\\\\\\\\\\\n', ' ')\n",
"    # Expand contractions BEFORE stripping punctuation: the apostrophe is\n",
"    # itself punctuation, so these replacements could never match after the\n",
"    # re.sub below removed it (original order was inverted).\n",
"    text = text.replace(\"'t\", \" not\").replace(\"'s\", \" is\").replace(\"'ll\", \" will\").replace(\"'m\", \" am\").replace(\"'ve\", \" have\")\n",
"    # Drop all remaining Unicode punctuation.\n",
"    text = re.sub(r'\\p{P}', '', text)\n",
"\n",
"    return text"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "758cf94a",
"metadata": {},
"outputs": [],
"source": [
"# Columns 6 and 7 of in.tsv are the left/right contexts; expected.tsv holds the gap word.\n",
"# QUOTE_NONE + on_bad_lines='skip' because the raw TSV contains unescaped quotes.\n",
"train_data = pd.read_csv('train/in.tsv.xz', sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)\n",
"train_labels = pd.read_csv('train/expected.tsv', sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)\n",
"\n",
"train_data = train_data[[6, 7]]\n",
"# The appended label column keeps its original name 0 (header=None above).\n",
"train_data = pd.concat([train_data, train_labels], axis=1)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "384922c7",
"metadata": {},
"outputs": [],
"source": [
"class TrainCorpus:\n",
"    \"\"\"Stream of tokenized rows (left context + gap word + right context) for vocab building.\"\"\"\n",
"    def __init__(self, data):\n",
"        self.data = data\n",
"\n",
"    def __iter__(self):\n",
"        for _, row in self.data.iterrows():\n",
"            # Join the three fields with spaces: plain concatenation fused the\n",
"            # boundary words into bogus tokens (e.g. 'contextgapword').\n",
"            text = str(row[6]) + ' ' + str(row[0]) + ' ' + str(row[7])\n",
"            text = clean_text(text)\n",
"            yield word_tokenize(text)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "3da5f19b",
"metadata": {},
"outputs": [],
"source": [
"# Word2Vec is used only for its vocabulary machinery (min_count=10 prunes rare words);\n",
"# the embeddings themselves are trained inside the LSTM model below.\n",
"train_sentences = TrainCorpus(train_data.head(80000))\n",
"w2v_model = Word2Vec(vector_size=100, min_count=10)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "183d43be",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"81477\n"
]
}
],
"source": [
"w2v_model.build_vocab(corpus_iterable=train_sentences)\n",
"\n",
"key_to_index = w2v_model.wv.key_to_index\n",
"index_to_key = w2v_model.wv.index_to_key\n",
"\n",
"# Append an out-of-vocabulary token so unknown words map to a valid index.\n",
"index_to_key.append('<unk>')\n",
"key_to_index['<unk>'] = len(index_to_key) - 1\n",
"\n",
"vocab_size = len(index_to_key)\n",
"print(vocab_size)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "e63dd9fe",
"metadata": {},
"outputs": [],
"source": [
"class TrainDataset(torch.utils.data.IterableDataset):\n",
"    \"\"\"Yields (input, target) index windows for next-word LM training.\n",
"\n",
"    Each sample is a 5-token input window and the same window shifted by one\n",
"    token as the target. With reversed=True the token stream is reversed\n",
"    first (used to train the right-to-left model).\n",
"    \"\"\"\n",
"    def __init__(self, data, index_to_key, key_to_index, reversed=False):\n",
"        self.reversed = reversed\n",
"        self.data = data\n",
"        self.index_to_key = index_to_key\n",
"        self.key_to_index = key_to_index\n",
"        self.vocab_size = len(key_to_index)\n",
"\n",
"    def __iter__(self):\n",
"        # Hoist the OOV index out of the per-token lookups.\n",
"        unk = self.key_to_index['<unk>']\n",
"        for _, row in self.data.iterrows():\n",
"            # Space-join the fields: plain concatenation fused the boundary\n",
"            # words of context and gap word into bogus tokens.\n",
"            text = str(row[6]) + ' ' + str(row[0]) + ' ' + str(row[7])\n",
"            tokens = word_tokenize(clean_text(text))\n",
"            if self.reversed:\n",
"                tokens = tokens[::-1]\n",
"            for i in range(5, len(tokens)):\n",
"                input_embed = [self.key_to_index.get(w, unk) for w in tokens[i-5:i]]\n",
"                target_embed = [self.key_to_index.get(w, unk) for w in tokens[i-4:i+1]]\n",
"                yield np.asarray(input_embed, dtype=np.int64), np.asarray(target_embed, dtype=np.int64)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "7c60ddc1",
"metadata": {},
"outputs": [],
"source": [
"class Model(nn.Module):\n",
"    \"\"\"Embedding -> 2-layer LSTM -> linear head over the vocabulary.\"\"\"\n",
"    def __init__(self, embed_size, vocab_size):\n",
"        super(Model, self).__init__()\n",
"        self.embed_size = embed_size\n",
"        self.vocab_size = vocab_size\n",
"        self.lstm_size = 128\n",
"        self.num_layers = 2\n",
"\n",
"        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=self.embed_size)\n",
"        self.lstm = nn.LSTM(input_size=self.embed_size, hidden_size=self.lstm_size, num_layers=self.num_layers, dropout=0.2)\n",
"        self.fc = nn.Linear(self.lstm_size, vocab_size)\n",
"\n",
"    def forward(self, x, prev_state = None):\n",
"        # Return raw logits: CrossEntropyLoss applies log-softmax itself, so\n",
"        # the previously computed-but-unused torch.softmax was dead code.\n",
"        embed = self.embed(x)\n",
"        output, state = self.lstm(embed, prev_state)\n",
"        logits = self.fc(output)\n",
"        return logits, state\n",
"\n",
"    def init_state(self, sequence_length):\n",
"        # Fixed: referenced nonexistent self.gru_size (AttributeError at\n",
"        # runtime); the LSTM hidden size is self.lstm_size.\n",
"        zeros = torch.zeros(self.num_layers, sequence_length, self.lstm_size).to(device)\n",
"        return (zeros, zeros)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "1c7b8fab",
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"from torch.optim import Adam\n",
"\n",
"def train(dataset, model, max_epochs, batch_size):\n",
"    \"\"\"Run the LM training loop, printing the loss every 1000 batches.\"\"\"\n",
"    model.train()\n",
"\n",
"    loader = DataLoader(dataset, batch_size=batch_size)\n",
"    criterion = nn.CrossEntropyLoss()\n",
"    optimizer = Adam(model.parameters(), lr=0.001)\n",
"\n",
"    for epoch in range(max_epochs):\n",
"        for batch, (x, y) in enumerate(loader):\n",
"            optimizer.zero_grad()\n",
"            x, y = x.to(device), y.to(device)\n",
"\n",
"            y_pred, (state_h, state_c) = model(x)\n",
"            # CrossEntropyLoss expects (batch, classes, seq) for sequence targets.\n",
"            loss = criterion(y_pred.transpose(1, 2), y)\n",
"            loss.backward()\n",
"            optimizer.step()\n",
"\n",
"            if batch % 1000 == 0:\n",
"                print(f'epoch: {epoch}, update in batch {batch}/???, loss: {loss.item()}')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "3531d21d",
"metadata": {},
"outputs": [],
"source": [
"# Forward model reads left-to-right; backward model reads reversed token streams.\n",
"train_dataset_front = TrainDataset(train_data.head(80000), index_to_key, key_to_index, False)\n",
"train_dataset_back = TrainDataset(train_data.tail(80000), index_to_key, key_to_index, True)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "f72f1f6d",
"metadata": {},
"outputs": [],
"source": [
"# Two independent LSTM LMs: one predicting forward, one backward (embed size 100).\n",
"model_front = Model(100, vocab_size).to(device)\n",
"model_back = Model(100, vocab_size).to(device)"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "d608d9fe",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, update in batch 0/???, loss: 11.314821243286133\n",
"epoch: 0, update in batch 1000/???, loss: 6.876476287841797\n",
"epoch: 0, update in batch 2000/???, loss: 7.133523464202881\n",
"epoch: 0, update in batch 3000/???, loss: 6.979971885681152\n",
"epoch: 0, update in batch 4000/???, loss: 7.018368721008301\n",
"epoch: 0, update in batch 5000/???, loss: 6.494096279144287\n",
"epoch: 0, update in batch 6000/???, loss: 6.448479652404785\n",
"epoch: 0, update in batch 7000/???, loss: 6.526387691497803\n",
"epoch: 0, update in batch 8000/???, loss: 6.536323547363281\n",
"epoch: 0, update in batch 9000/???, loss: 6.4919538497924805\n",
"epoch: 0, update in batch 10000/???, loss: 6.435188293457031\n",
"epoch: 0, update in batch 11000/???, loss: 6.934823513031006\n",
"epoch: 0, update in batch 12000/???, loss: 7.410381317138672\n",
"epoch: 0, update in batch 13000/???, loss: 8.227864265441895\n",
"epoch: 0, update in batch 14000/???, loss: 6.7139105796813965\n",
"epoch: 0, update in batch 15000/???, loss: 6.82781457901001\n",
"epoch: 0, update in batch 16000/???, loss: 6.637822151184082\n",
"epoch: 0, update in batch 17000/???, loss: 6.2633233070373535\n",
"epoch: 0, update in batch 18000/???, loss: 6.512040138244629\n",
"epoch: 0, update in batch 19000/???, loss: 5.745478630065918\n",
"epoch: 0, update in batch 20000/???, loss: 7.039064884185791\n",
"epoch: 0, update in batch 21000/???, loss: 7.151158332824707\n",
"epoch: 0, update in batch 22000/???, loss: 6.460148811340332\n",
"epoch: 0, update in batch 23000/???, loss: 7.396632194519043\n",
"epoch: 0, update in batch 24000/???, loss: 5.907363414764404\n",
"epoch: 0, update in batch 25000/???, loss: 6.669890403747559\n",
"epoch: 0, update in batch 26000/???, loss: 6.032290458679199\n",
"epoch: 0, update in batch 27000/???, loss: 6.192468166351318\n",
"epoch: 0, update in batch 28000/???, loss: 5.757508277893066\n",
"epoch: 0, update in batch 29000/???, loss: 7.097552299499512\n",
"epoch: 0, update in batch 30000/???, loss: 6.8356804847717285\n",
"epoch: 0, update in batch 31000/???, loss: 4.938998699188232\n",
"epoch: 0, update in batch 32000/???, loss: 6.34550142288208\n",
"epoch: 0, update in batch 33000/???, loss: 7.154759883880615\n",
"epoch: 0, update in batch 34000/???, loss: 6.8563055992126465\n",
"epoch: 0, update in batch 35000/???, loss: 6.831148624420166\n",
"epoch: 0, update in batch 36000/???, loss: 6.867754936218262\n",
"epoch: 0, update in batch 37000/???, loss: 6.911463260650635\n",
"epoch: 0, update in batch 38000/???, loss: 6.637528896331787\n",
"epoch: 0, update in batch 39000/???, loss: 6.822340488433838\n",
"epoch: 0, update in batch 40000/???, loss: 6.122499942779541\n",
"epoch: 0, update in batch 41000/???, loss: 6.454296112060547\n",
"epoch: 0, update in batch 42000/???, loss: 7.5895185470581055\n",
"epoch: 0, update in batch 43000/???, loss: 5.775805473327637\n",
"epoch: 0, update in batch 44000/???, loss: 5.973118305206299\n",
"epoch: 0, update in batch 45000/???, loss: 5.7727460861206055\n",
"epoch: 0, update in batch 46000/???, loss: 6.376847267150879\n",
"epoch: 0, update in batch 47000/???, loss: 5.739894866943359\n",
"epoch: 0, update in batch 48000/???, loss: 6.390743732452393\n",
"epoch: 0, update in batch 49000/???, loss: 7.724233150482178\n",
"epoch: 0, update in batch 50000/???, loss: 5.242608070373535\n",
"epoch: 0, update in batch 51000/???, loss: 5.412053108215332\n",
"epoch: 0, update in batch 52000/???, loss: 6.590373992919922\n",
"epoch: 0, update in batch 53000/???, loss: 6.46323299407959\n",
"epoch: 0, update in batch 54000/???, loss: 6.9850263595581055\n",
"epoch: 0, update in batch 55000/???, loss: 7.3167219161987305\n",
"epoch: 0, update in batch 56000/???, loss: 6.285423278808594\n",
"epoch: 0, update in batch 57000/???, loss: 7.417998313903809\n",
"epoch: 0, update in batch 58000/???, loss: 6.437861442565918\n",
"epoch: 0, update in batch 59000/???, loss: 6.522177219390869\n",
"epoch: 0, update in batch 60000/???, loss: 5.9156928062438965\n",
"epoch: 0, update in batch 61000/???, loss: 4.946429252624512\n",
"epoch: 0, update in batch 62000/???, loss: 6.633675575256348\n",
"epoch: 0, update in batch 63000/???, loss: 7.357038974761963\n",
"epoch: 0, update in batch 64000/???, loss: 5.774768352508545\n",
"epoch: 0, update in batch 65000/???, loss: 6.289044380187988\n",
"epoch: 0, update in batch 66000/???, loss: 6.127488136291504\n",
"epoch: 0, update in batch 67000/???, loss: 5.059685230255127\n",
"epoch: 0, update in batch 68000/???, loss: 6.5439910888671875\n",
"epoch: 0, update in batch 69000/???, loss: 6.679286956787109\n",
"epoch: 0, update in batch 70000/???, loss: 7.2232346534729\n",
"epoch: 0, update in batch 71000/???, loss: 6.13685941696167\n",
"epoch: 0, update in batch 72000/???, loss: 5.766592025756836\n",
"epoch: 0, update in batch 73000/???, loss: 6.772070407867432\n",
"epoch: 0, update in batch 74000/???, loss: 7.369122505187988\n",
"epoch: 0, update in batch 75000/???, loss: 6.598935127258301\n",
"epoch: 0, update in batch 76000/???, loss: 5.948511600494385\n",
"epoch: 0, update in batch 77000/???, loss: 6.507765769958496\n",
"epoch: 0, update in batch 78000/???, loss: 5.09373664855957\n",
"epoch: 0, update in batch 79000/???, loss: 5.9862494468688965\n",
"epoch: 0, update in batch 80000/???, loss: 6.106108665466309\n",
"epoch: 0, update in batch 81000/???, loss: 5.2747578620910645\n",
"epoch: 0, update in batch 82000/???, loss: 6.324326515197754\n",
"epoch: 0, update in batch 83000/???, loss: 5.914392471313477\n",
"epoch: 0, update in batch 84000/???, loss: 6.641409873962402\n",
"epoch: 0, update in batch 85000/???, loss: 6.287321090698242\n",
"epoch: 0, update in batch 86000/???, loss: 6.510883331298828\n",
"epoch: 0, update in batch 87000/???, loss: 6.458550930023193\n",
"epoch: 0, update in batch 88000/???, loss: 6.07730770111084\n",
"epoch: 0, update in batch 89000/???, loss: 6.2387471199035645\n",
"epoch: 0, update in batch 90000/???, loss: 5.63344669342041\n",
"epoch: 0, update in batch 91000/???, loss: 6.277956962585449\n",
"epoch: 0, update in batch 92000/???, loss: 6.841054439544678\n",
"epoch: 0, update in batch 93000/???, loss: 6.458809852600098\n",
"epoch: 0, update in batch 94000/???, loss: 7.471741676330566\n",
"epoch: 0, update in batch 95000/???, loss: 6.461136817932129\n",
"epoch: 0, update in batch 96000/???, loss: 5.718675136566162\n",
"epoch: 0, update in batch 97000/???, loss: 4.4265007972717285\n",
"epoch: 0, update in batch 98000/???, loss: 7.05142879486084\n",
"epoch: 0, update in batch 99000/???, loss: 6.341854572296143\n",
"epoch: 0, update in batch 100000/???, loss: 6.834918022155762\n",
"epoch: 0, update in batch 101000/???, loss: 5.367598056793213\n",
"epoch: 0, update in batch 102000/???, loss: 5.716221809387207\n",
"epoch: 0, update in batch 103000/???, loss: 6.9465742111206055\n",
"epoch: 0, update in batch 104000/???, loss: 5.976019382476807\n",
"epoch: 0, update in batch 105000/???, loss: 6.125661849975586\n",
"epoch: 0, update in batch 106000/???, loss: 6.724229335784912\n",
"epoch: 0, update in batch 107000/???, loss: 6.446004390716553\n",
"epoch: 0, update in batch 108000/???, loss: 6.4710845947265625\n",
"epoch: 0, update in batch 109000/???, loss: 6.5926103591918945\n",
"epoch: 0, update in batch 110000/???, loss: 6.966839790344238\n",
"epoch: 0, update in batch 111000/???, loss: 7.263918876647949\n",
"epoch: 0, update in batch 112000/???, loss: 6.7561750411987305\n",
"epoch: 0, update in batch 113000/???, loss: 6.142555236816406\n",
"epoch: 0, update in batch 114000/???, loss: 5.974082946777344\n",
"epoch: 0, update in batch 115000/???, loss: 5.565796852111816\n",
"epoch: 0, update in batch 116000/???, loss: 6.4826202392578125\n",
"epoch: 0, update in batch 117000/???, loss: 5.643266201019287\n",
"epoch: 0, update in batch 118000/???, loss: 6.360909461975098\n",
"epoch: 0, update in batch 119000/???, loss: 5.4074201583862305\n",
"epoch: 0, update in batch 120000/???, loss: 7.1339569091796875\n",
"epoch: 0, update in batch 121000/???, loss: 6.786561012268066\n",
"epoch: 0, update in batch 122000/???, loss: 6.329574108123779\n",
"epoch: 0, update in batch 123000/???, loss: 7.21968936920166\n",
"epoch: 0, update in batch 124000/???, loss: 5.351359844207764\n",
"epoch: 0, update in batch 125000/???, loss: 7.962380886077881\n",
"epoch: 0, update in batch 126000/???, loss: 6.351782321929932\n",
"epoch: 0, update in batch 127000/???, loss: 6.8343048095703125\n",
"epoch: 0, update in batch 128000/???, loss: 6.129800319671631\n",
"epoch: 0, update in batch 129000/???, loss: 6.68627405166626\n",
"epoch: 0, update in batch 130000/???, loss: 6.498664855957031\n",
"epoch: 0, update in batch 131000/???, loss: 5.724549293518066\n",
"epoch: 0, update in batch 132000/???, loss: 7.041095733642578\n",
"epoch: 0, update in batch 133000/???, loss: 5.901988983154297\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, update in batch 134000/???, loss: 6.055495262145996\n",
"epoch: 0, update in batch 135000/???, loss: 6.363399982452393\n",
"epoch: 0, update in batch 136000/???, loss: 7.45733642578125\n",
"epoch: 0, update in batch 137000/???, loss: 6.960203647613525\n",
"epoch: 0, update in batch 138000/???, loss: 6.986503601074219\n",
"epoch: 0, update in batch 139000/???, loss: 5.7938127517700195\n",
"epoch: 0, update in batch 140000/???, loss: 5.559916019439697\n",
"epoch: 0, update in batch 141000/???, loss: 5.551616668701172\n",
"epoch: 0, update in batch 142000/???, loss: 5.386819839477539\n",
"epoch: 0, update in batch 143000/???, loss: 6.826618194580078\n",
"epoch: 0, update in batch 144000/???, loss: 6.106345176696777\n",
"epoch: 0, update in batch 145000/???, loss: 6.812024116516113\n",
"epoch: 0, update in batch 146000/???, loss: 6.347486972808838\n",
"epoch: 0, update in batch 147000/???, loss: 6.20189094543457\n",
"epoch: 0, update in batch 148000/???, loss: 5.5717034339904785\n",
"epoch: 0, update in batch 149000/???, loss: 6.884232521057129\n",
"epoch: 0, update in batch 150000/???, loss: 6.8074846267700195\n",
"epoch: 0, update in batch 151000/???, loss: 7.028794288635254\n",
"epoch: 0, update in batch 152000/???, loss: 5.201214790344238\n",
"epoch: 0, update in batch 153000/???, loss: 5.1864013671875\n",
"epoch: 0, update in batch 154000/???, loss: 6.4473114013671875\n",
"epoch: 0, update in batch 155000/???, loss: 4.9203643798828125\n",
"epoch: 0, update in batch 156000/???, loss: 6.829309940338135\n",
"epoch: 0, update in batch 157000/???, loss: 7.045801639556885\n",
"epoch: 0, update in batch 158000/???, loss: 6.4073967933654785\n",
"epoch: 0, update in batch 159000/???, loss: 6.494145393371582\n",
"epoch: 0, update in batch 160000/???, loss: 6.682474613189697\n",
"epoch: 0, update in batch 161000/???, loss: 5.125617980957031\n",
"epoch: 0, update in batch 162000/???, loss: 5.915367126464844\n",
"epoch: 0, update in batch 163000/???, loss: 6.4779157638549805\n",
"epoch: 0, update in batch 164000/???, loss: 5.547584533691406\n",
"epoch: 0, update in batch 165000/???, loss: 6.134579181671143\n",
"epoch: 0, update in batch 166000/???, loss: 5.300144672393799\n",
"epoch: 0, update in batch 167000/???, loss: 6.53488826751709\n",
"epoch: 0, update in batch 168000/???, loss: 6.711917877197266\n",
"epoch: 0, update in batch 169000/???, loss: 7.0150322914123535\n",
"epoch: 0, update in batch 170000/???, loss: 5.681846618652344\n",
"epoch: 0, update in batch 171000/???, loss: 6.583130836486816\n",
"epoch: 0, update in batch 172000/???, loss: 6.411820411682129\n",
"epoch: 0, update in batch 173000/???, loss: 5.725490093231201\n",
"epoch: 0, update in batch 174000/???, loss: 6.651374816894531\n",
"epoch: 0, update in batch 175000/???, loss: 5.800152778625488\n",
"epoch: 0, update in batch 176000/???, loss: 6.862998962402344\n",
"epoch: 0, update in batch 177000/???, loss: 6.668658256530762\n",
"epoch: 0, update in batch 178000/???, loss: 6.519270896911621\n",
"epoch: 0, update in batch 179000/???, loss: 6.716788291931152\n",
"epoch: 0, update in batch 180000/???, loss: 6.675846099853516\n",
"epoch: 0, update in batch 181000/???, loss: 6.598060607910156\n",
"epoch: 0, update in batch 182000/???, loss: 6.638599395751953\n",
"epoch: 0, update in batch 183000/???, loss: 5.693145275115967\n",
"epoch: 0, update in batch 184000/???, loss: 5.175653457641602\n",
"epoch: 0, update in batch 185000/???, loss: 6.659600734710693\n",
"epoch: 0, update in batch 186000/???, loss: 5.782421112060547\n",
"epoch: 0, update in batch 187000/???, loss: 6.1736297607421875\n",
"epoch: 0, update in batch 188000/???, loss: 5.38541316986084\n",
"epoch: 0, update in batch 189000/???, loss: 6.238187789916992\n",
"epoch: 0, update in batch 190000/???, loss: 6.10030460357666\n",
"epoch: 0, update in batch 191000/???, loss: 6.680960655212402\n",
"epoch: 0, update in batch 192000/???, loss: 6.600944519042969\n",
"epoch: 0, update in batch 193000/???, loss: 6.171700477600098\n",
"epoch: 0, update in batch 194000/???, loss: 7.250021934509277\n",
"epoch: 0, update in batch 195000/???, loss: 5.968771934509277\n",
"epoch: 0, update in batch 196000/???, loss: 7.107605934143066\n",
"epoch: 0, update in batch 197000/???, loss: 6.743283748626709\n",
"epoch: 0, update in batch 198000/???, loss: 7.130635738372803\n",
"epoch: 0, update in batch 199000/???, loss: 6.37470817565918\n",
"epoch: 0, update in batch 200000/???, loss: 6.050590515136719\n",
"epoch: 0, update in batch 201000/???, loss: 5.468177318572998\n",
"epoch: 0, update in batch 202000/???, loss: 6.343471527099609\n",
"epoch: 0, update in batch 203000/???, loss: 6.890538692474365\n",
"epoch: 0, update in batch 204000/???, loss: 7.018721580505371\n",
"epoch: 0, update in batch 205000/???, loss: 6.131939888000488\n",
"epoch: 0, update in batch 206000/???, loss: 6.219918251037598\n",
"epoch: 0, update in batch 207000/???, loss: 5.858460426330566\n",
"epoch: 0, update in batch 208000/???, loss: 6.33021354675293\n",
"epoch: 0, update in batch 209000/???, loss: 6.249329566955566\n",
"epoch: 0, update in batch 210000/???, loss: 6.263474941253662\n",
"epoch: 0, update in batch 211000/???, loss: 6.731234550476074\n",
"epoch: 0, update in batch 212000/???, loss: 5.978096961975098\n",
"epoch: 0, update in batch 213000/???, loss: 5.148629188537598\n",
"epoch: 0, update in batch 214000/???, loss: 6.79285192489624\n",
"epoch: 0, update in batch 215000/???, loss: 5.943106651306152\n",
"epoch: 0, update in batch 216000/???, loss: 5.749272346496582\n",
"epoch: 0, update in batch 217000/???, loss: 6.991009712219238\n",
"epoch: 0, update in batch 218000/???, loss: 6.21205997467041\n",
"epoch: 0, update in batch 219000/???, loss: 7.519427299499512\n",
"epoch: 0, update in batch 220000/???, loss: 5.699267387390137\n",
"epoch: 0, update in batch 221000/???, loss: 6.05304479598999\n",
"epoch: 0, update in batch 222000/???, loss: 6.422593116760254\n",
"epoch: 0, update in batch 223000/???, loss: 6.179877281188965\n",
"epoch: 0, update in batch 224000/???, loss: 4.841546058654785\n",
"epoch: 0, update in batch 225000/???, loss: 6.666176795959473\n",
"epoch: 0, update in batch 226000/???, loss: 5.994054794311523\n",
"epoch: 0, update in batch 227000/???, loss: 6.792928218841553\n",
"epoch: 0, update in batch 228000/???, loss: 6.9571661949157715\n",
"epoch: 0, update in batch 229000/???, loss: 6.198942184448242\n",
"epoch: 0, update in batch 230000/???, loss: 5.944539546966553\n",
"epoch: 0, update in batch 231000/???, loss: 6.188899040222168\n",
"epoch: 0, update in batch 232000/???, loss: 5.826596260070801\n",
"epoch: 0, update in batch 233000/???, loss: 5.728386878967285\n",
"epoch: 0, update in batch 234000/???, loss: 7.6024885177612305\n",
"epoch: 0, update in batch 235000/???, loss: 6.728615760803223\n",
"epoch: 0, update in batch 236000/???, loss: 6.2461137771606445\n",
"epoch: 0, update in batch 237000/???, loss: 6.3110551834106445\n",
"epoch: 0, update in batch 238000/???, loss: 6.12617826461792\n",
"epoch: 0, update in batch 239000/???, loss: 6.6068243980407715\n",
"epoch: 0, update in batch 240000/???, loss: 7.015429496765137\n",
"epoch: 0, update in batch 241000/???, loss: 8.444561004638672\n",
"epoch: 0, update in batch 242000/???, loss: 7.289303779602051\n",
"epoch: 0, update in batch 243000/???, loss: 6.260491371154785\n",
"epoch: 0, update in batch 244000/???, loss: 7.60237979888916\n",
"epoch: 0, update in batch 245000/???, loss: 6.295613765716553\n",
"epoch: 0, update in batch 246000/???, loss: 5.929107666015625\n",
"epoch: 0, update in batch 247000/???, loss: 5.835566997528076\n",
"epoch: 0, update in batch 248000/???, loss: 5.837784290313721\n",
"epoch: 0, update in batch 249000/???, loss: 5.972233772277832\n",
"epoch: 0, update in batch 250000/???, loss: 6.0488996505737305\n",
"epoch: 0, update in batch 251000/???, loss: 5.712280750274658\n",
"epoch: 0, update in batch 252000/???, loss: 5.9513702392578125\n",
"epoch: 0, update in batch 253000/???, loss: 5.636294364929199\n",
"epoch: 0, update in batch 254000/???, loss: 5.91803503036499\n",
"epoch: 0, update in batch 255000/???, loss: 7.285937309265137\n",
"epoch: 0, update in batch 256000/???, loss: 6.4795637130737305\n",
"epoch: 0, update in batch 257000/???, loss: 6.0709991455078125\n",
"epoch: 0, update in batch 258000/???, loss: 5.8723649978637695\n",
"epoch: 0, update in batch 259000/???, loss: 5.174002647399902\n",
"epoch: 0, update in batch 260000/???, loss: 6.504033088684082\n",
"epoch: 0, update in batch 261000/???, loss: 7.088961601257324\n",
"epoch: 0, update in batch 262000/???, loss: 6.2242960929870605\n",
"epoch: 0, update in batch 263000/???, loss: 5.970286846160889\n",
"epoch: 0, update in batch 264000/???, loss: 5.961676597595215\n",
"epoch: 0, update in batch 265000/???, loss: 6.170080661773682\n",
"epoch: 0, update in batch 266000/???, loss: 5.477972507476807\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, update in batch 267000/???, loss: 6.188825607299805\n",
"epoch: 0, update in batch 268000/???, loss: 6.518698215484619\n",
"epoch: 0, update in batch 269000/???, loss: 5.663434028625488\n",
"epoch: 0, update in batch 270000/???, loss: 5.978742599487305\n",
"epoch: 0, update in batch 271000/???, loss: 6.217379093170166\n",
"epoch: 0, update in batch 272000/???, loss: 5.426600933074951\n",
"epoch: 0, update in batch 273000/???, loss: 6.7220964431762695\n",
"epoch: 0, update in batch 274000/???, loss: 4.276306629180908\n",
"epoch: 0, update in batch 275000/???, loss: 5.420112609863281\n",
"epoch: 0, update in batch 276000/???, loss: 5.934456825256348\n",
"epoch: 0, update in batch 277000/???, loss: 7.186459541320801\n",
"epoch: 0, update in batch 278000/???, loss: 6.126835823059082\n",
"epoch: 0, update in batch 279000/???, loss: 5.727339267730713\n",
"epoch: 0, update in batch 280000/???, loss: 5.725864410400391\n",
"epoch: 0, update in batch 281000/???, loss: 5.47005033493042\n",
"epoch: 0, update in batch 282000/???, loss: 6.217499732971191\n",
"epoch: 0, update in batch 283000/???, loss: 6.022196292877197\n",
"epoch: 0, update in batch 284000/???, loss: 5.932379722595215\n",
"epoch: 0, update in batch 285000/???, loss: 6.321987628936768\n",
"epoch: 0, update in batch 286000/???, loss: 7.480570316314697\n",
"epoch: 0, update in batch 287000/???, loss: 5.169373512268066\n",
"epoch: 0, update in batch 288000/???, loss: 6.301320552825928\n",
"epoch: 0, update in batch 289000/???, loss: 6.4635009765625\n",
"epoch: 0, update in batch 290000/???, loss: 6.8701887130737305\n",
"epoch: 0, update in batch 291000/???, loss: 6.036175727844238\n",
"epoch: 0, update in batch 292000/???, loss: 6.705732822418213\n",
"epoch: 0, update in batch 293000/???, loss: 6.99608850479126\n",
"epoch: 0, update in batch 294000/???, loss: 6.50225305557251\n",
"epoch: 0, update in batch 295000/???, loss: 6.03929328918457\n",
"epoch: 0, update in batch 296000/???, loss: 5.498082160949707\n",
"epoch: 0, update in batch 297000/???, loss: 6.04677677154541\n",
"epoch: 0, update in batch 298000/???, loss: 6.482898712158203\n",
"epoch: 0, update in batch 299000/???, loss: 7.235076904296875\n",
"epoch: 0, update in batch 300000/???, loss: 6.019383907318115\n",
"epoch: 0, update in batch 301000/???, loss: 7.082001686096191\n",
"epoch: 0, update in batch 302000/???, loss: 6.447659492492676\n",
"epoch: 0, update in batch 303000/???, loss: 5.94022798538208\n",
"epoch: 0, update in batch 304000/???, loss: 6.459266662597656\n",
"epoch: 0, update in batch 305000/???, loss: 6.281588077545166\n",
"epoch: 0, update in batch 306000/???, loss: 7.022011756896973\n",
"epoch: 0, update in batch 307000/???, loss: 6.1802263259887695\n",
"epoch: 0, update in batch 308000/???, loss: 4.189492225646973\n",
"epoch: 0, update in batch 309000/???, loss: 6.7040696144104\n",
"epoch: 0, update in batch 310000/???, loss: 6.589522361755371\n",
"epoch: 0, update in batch 311000/???, loss: 6.243889808654785\n",
"epoch: 0, update in batch 312000/???, loss: 5.490180015563965\n",
"epoch: 0, update in batch 313000/???, loss: 5.9699201583862305\n",
"epoch: 0, update in batch 314000/???, loss: 7.321981906890869\n",
"epoch: 0, update in batch 315000/???, loss: 4.731215953826904\n",
"epoch: 0, update in batch 316000/???, loss: 5.845946788787842\n",
"epoch: 0, update in batch 317000/???, loss: 5.917788505554199\n",
"epoch: 0, update in batch 318000/???, loss: 6.420014381408691\n",
"epoch: 0, update in batch 319000/???, loss: 6.550830841064453\n",
"epoch: 0, update in batch 320000/???, loss: 6.751360893249512\n",
"epoch: 0, update in batch 321000/???, loss: 5.025134086608887\n",
"epoch: 0, update in batch 322000/???, loss: 6.368621826171875\n",
"epoch: 0, update in batch 323000/???, loss: 6.2042083740234375\n",
"epoch: 0, update in batch 324000/???, loss: 6.173147678375244\n",
"epoch: 0, update in batch 325000/???, loss: 5.865999221801758\n",
"epoch: 0, update in batch 326000/???, loss: 6.844902992248535\n",
"epoch: 0, update in batch 327000/???, loss: 6.080742359161377\n",
"epoch: 0, update in batch 328000/???, loss: 5.41788387298584\n",
"epoch: 0, update in batch 329000/???, loss: 5.831374645233154\n",
"epoch: 0, update in batch 330000/???, loss: 6.4492506980896\n",
"epoch: 0, update in batch 331000/???, loss: 6.220627784729004\n",
"epoch: 0, update in batch 332000/???, loss: 5.880006313323975\n",
"epoch: 0, update in batch 333000/???, loss: 6.806972503662109\n",
"epoch: 0, update in batch 334000/???, loss: 7.165728569030762\n",
"epoch: 0, update in batch 335000/???, loss: 6.322948932647705\n",
"epoch: 0, update in batch 336000/???, loss: 6.206046104431152\n",
"epoch: 0, update in batch 337000/???, loss: 6.097958564758301\n",
"epoch: 0, update in batch 338000/???, loss: 6.7682952880859375\n",
"epoch: 0, update in batch 339000/???, loss: 5.2390642166137695\n",
"epoch: 0, update in batch 340000/???, loss: 6.913119316101074\n"
]
}
],
"source": [
"# One epoch, batch size 64, on the forward (left-to-right) model.\n",
"train(train_dataset_front, model_front, 1, 64)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "132d9157",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, update in batch 0/???, loss: 11.3253755569458\n",
"epoch: 0, update in batch 1000/???, loss: 5.709358215332031\n",
"epoch: 0, update in batch 2000/???, loss: 7.989391326904297\n",
"epoch: 0, update in batch 3000/???, loss: 6.578714847564697\n",
"epoch: 0, update in batch 4000/???, loss: 7.051873207092285\n",
"epoch: 0, update in batch 5000/???, loss: 6.85653018951416\n",
"epoch: 0, update in batch 6000/???, loss: 6.812790870666504\n",
"epoch: 0, update in batch 7000/???, loss: 6.9604010581970215\n",
"epoch: 0, update in batch 8000/???, loss: 6.798591613769531\n",
"epoch: 0, update in batch 9000/???, loss: 6.415241241455078\n",
"epoch: 0, update in batch 10000/???, loss: 6.6636223793029785\n",
"epoch: 0, update in batch 11000/???, loss: 6.593747138977051\n",
"epoch: 0, update in batch 12000/???, loss: 6.914702415466309\n",
"epoch: 0, update in batch 13000/???, loss: 5.542675971984863\n",
"epoch: 0, update in batch 14000/???, loss: 6.5461883544921875\n",
"epoch: 0, update in batch 15000/???, loss: 7.507067680358887\n",
"epoch: 0, update in batch 16000/???, loss: 5.425755500793457\n",
"epoch: 0, update in batch 17000/???, loss: 6.285205841064453\n",
"epoch: 0, update in batch 18000/???, loss: 4.223124027252197\n",
"epoch: 0, update in batch 19000/???, loss: 6.530254364013672\n",
"epoch: 0, update in batch 20000/???, loss: 6.091847896575928\n",
"epoch: 0, update in batch 21000/???, loss: 7.088344573974609\n",
"epoch: 0, update in batch 22000/???, loss: 5.925537109375\n",
"epoch: 0, update in batch 23000/???, loss: 6.3628082275390625\n",
"epoch: 0, update in batch 24000/???, loss: 6.604581356048584\n",
"epoch: 0, update in batch 25000/???, loss: 6.2706499099731445\n",
"epoch: 0, update in batch 26000/???, loss: 6.114742755889893\n",
"epoch: 0, update in batch 27000/???, loss: 5.686783790588379\n",
"epoch: 0, update in batch 28000/???, loss: 5.5114521980285645\n",
"epoch: 0, update in batch 29000/???, loss: 6.999403953552246\n",
"epoch: 0, update in batch 30000/???, loss: 5.834499359130859\n",
"epoch: 0, update in batch 31000/???, loss: 5.873156547546387\n",
"epoch: 0, update in batch 32000/???, loss: 6.246962547302246\n",
"epoch: 0, update in batch 33000/???, loss: 6.742733955383301\n",
"epoch: 0, update in batch 34000/???, loss: 6.832881927490234\n",
"epoch: 0, update in batch 35000/???, loss: 6.625868320465088\n",
"epoch: 0, update in batch 36000/???, loss: 6.653105735778809\n",
"epoch: 0, update in batch 37000/???, loss: 6.104651927947998\n",
"epoch: 0, update in batch 38000/???, loss: 6.301898002624512\n",
"epoch: 0, update in batch 39000/???, loss: 7.377936363220215\n",
"epoch: 0, update in batch 40000/???, loss: 6.26895809173584\n",
"epoch: 0, update in batch 41000/???, loss: 6.602926731109619\n",
"epoch: 0, update in batch 42000/???, loss: 6.419803619384766\n",
"epoch: 0, update in batch 43000/???, loss: 7.187136650085449\n",
"epoch: 0, update in batch 44000/???, loss: 6.382015705108643\n",
"epoch: 0, update in batch 45000/???, loss: 6.044090747833252\n",
"epoch: 0, update in batch 46000/???, loss: 5.707688808441162\n",
"epoch: 0, update in batch 47000/???, loss: 7.007757663726807\n",
"epoch: 0, update in batch 48000/???, loss: 5.365390300750732\n",
"epoch: 0, update in batch 49000/???, loss: 5.510242938995361\n",
"epoch: 0, update in batch 50000/???, loss: 5.955991268157959\n",
"epoch: 0, update in batch 51000/???, loss: 6.2313032150268555\n",
"epoch: 0, update in batch 52000/???, loss: 8.19306468963623\n",
"epoch: 0, update in batch 53000/???, loss: 6.345375061035156\n",
"epoch: 0, update in batch 54000/???, loss: 7.044759273529053\n",
"epoch: 0, update in batch 55000/???, loss: 6.2544779777526855\n",
"epoch: 0, update in batch 56000/???, loss: 6.315605163574219\n",
"epoch: 0, update in batch 57000/???, loss: 5.632706642150879\n",
"epoch: 0, update in batch 58000/???, loss: 6.0897536277771\n",
"epoch: 0, update in batch 59000/???, loss: 5.562952518463135\n",
"epoch: 0, update in batch 60000/???, loss: 5.519134044647217\n",
"epoch: 0, update in batch 61000/???, loss: 6.394771099090576\n",
"epoch: 0, update in batch 62000/???, loss: 6.147246360778809\n",
"epoch: 0, update in batch 63000/???, loss: 5.798914909362793\n",
"epoch: 0, update in batch 64000/???, loss: 6.026059627532959\n",
"epoch: 0, update in batch 65000/???, loss: 6.4533233642578125\n",
"epoch: 0, update in batch 66000/???, loss: 6.383795738220215\n",
"epoch: 0, update in batch 67000/???, loss: 6.466322898864746\n",
"epoch: 0, update in batch 68000/???, loss: 6.8227715492248535\n",
"epoch: 0, update in batch 69000/???, loss: 6.283398151397705\n",
"epoch: 0, update in batch 70000/???, loss: 4.547608375549316\n",
"epoch: 0, update in batch 71000/???, loss: 6.008975028991699\n",
"epoch: 0, update in batch 72000/???, loss: 5.674825191497803\n",
"epoch: 0, update in batch 73000/???, loss: 5.134644508361816\n",
"epoch: 0, update in batch 74000/???, loss: 6.906868934631348\n",
"epoch: 0, update in batch 75000/???, loss: 6.672898292541504\n",
"epoch: 0, update in batch 76000/???, loss: 5.813290596008301\n",
"epoch: 0, update in batch 77000/???, loss: 6.296219825744629\n",
"epoch: 0, update in batch 78000/???, loss: 6.531443119049072\n",
"epoch: 0, update in batch 79000/???, loss: 6.437461853027344\n",
"epoch: 0, update in batch 80000/???, loss: 6.2280778884887695\n",
"epoch: 0, update in batch 81000/???, loss: 6.805241584777832\n",
"epoch: 0, update in batch 82000/???, loss: 7.044824123382568\n",
"epoch: 0, update in batch 83000/???, loss: 7.348274230957031\n",
"epoch: 0, update in batch 84000/???, loss: 5.826806545257568\n",
"epoch: 0, update in batch 85000/???, loss: 5.474950313568115\n",
"epoch: 0, update in batch 86000/???, loss: 6.497323036193848\n",
"epoch: 0, update in batch 87000/???, loss: 5.88934850692749\n",
"epoch: 0, update in batch 88000/???, loss: 5.371798038482666\n",
"epoch: 0, update in batch 89000/???, loss: 6.093968391418457\n",
"epoch: 0, update in batch 90000/???, loss: 6.115981578826904\n",
"epoch: 0, update in batch 91000/???, loss: 6.504927158355713\n",
"epoch: 0, update in batch 92000/???, loss: 6.239808082580566\n",
"epoch: 0, update in batch 93000/???, loss: 5.384994983673096\n",
"epoch: 0, update in batch 94000/???, loss: 6.422779083251953\n",
"epoch: 0, update in batch 95000/???, loss: 7.163965702056885\n",
"epoch: 0, update in batch 96000/???, loss: 6.44806432723999\n",
"epoch: 0, update in batch 97000/???, loss: 6.153664588928223\n",
"epoch: 0, update in batch 98000/???, loss: 5.9013776779174805\n",
"epoch: 0, update in batch 99000/???, loss: 6.198166847229004\n",
"epoch: 0, update in batch 100000/???, loss: 5.752341270446777\n",
"epoch: 0, update in batch 101000/???, loss: 6.455883979797363\n",
"epoch: 0, update in batch 102000/???, loss: 5.270313262939453\n",
"epoch: 0, update in batch 103000/???, loss: 6.475237846374512\n",
"epoch: 0, update in batch 104000/???, loss: 6.2444844245910645\n",
"epoch: 0, update in batch 105000/???, loss: 6.1563720703125\n",
"epoch: 0, update in batch 106000/???, loss: 6.12777853012085\n",
"epoch: 0, update in batch 107000/???, loss: 6.449145317077637\n",
"epoch: 0, update in batch 108000/???, loss: 6.515239715576172\n",
"epoch: 0, update in batch 109000/???, loss: 5.6317644119262695\n",
"epoch: 0, update in batch 110000/???, loss: 6.09606409072876\n",
"epoch: 0, update in batch 111000/???, loss: 7.069797515869141\n",
"epoch: 0, update in batch 112000/???, loss: 7.456076145172119\n",
"epoch: 0, update in batch 113000/???, loss: 6.668386936187744\n",
"epoch: 0, update in batch 114000/???, loss: 7.705430507659912\n",
"epoch: 0, update in batch 115000/???, loss: 6.983656883239746\n",
"epoch: 0, update in batch 116000/???, loss: 6.320417404174805\n",
"epoch: 0, update in batch 117000/???, loss: 7.184473991394043\n",
"epoch: 0, update in batch 118000/???, loss: 6.603268623352051\n",
"epoch: 0, update in batch 119000/???, loss: 6.670085906982422\n",
"epoch: 0, update in batch 120000/???, loss: 6.748586177825928\n",
"epoch: 0, update in batch 121000/???, loss: 6.353959560394287\n",
"epoch: 0, update in batch 122000/???, loss: 5.138751029968262\n",
"epoch: 0, update in batch 123000/???, loss: 6.507109642028809\n",
"epoch: 0, update in batch 124000/???, loss: 6.360246181488037\n",
"epoch: 0, update in batch 125000/???, loss: 7.164086818695068\n",
"epoch: 0, update in batch 126000/???, loss: 5.610747337341309\n",
"epoch: 0, update in batch 127000/???, loss: 5.066179275512695\n",
"epoch: 0, update in batch 128000/???, loss: 5.688697814941406\n",
"epoch: 0, update in batch 129000/???, loss: 6.960330963134766\n",
"epoch: 0, update in batch 130000/???, loss: 5.818534851074219\n",
"epoch: 0, update in batch 131000/???, loss: 6.186715602874756\n",
"epoch: 0, update in batch 132000/???, loss: 5.825492858886719\n",
"epoch: 0, update in batch 133000/???, loss: 5.576340675354004\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, update in batch 134000/???, loss: 5.503821849822998\n",
"epoch: 0, update in batch 135000/???, loss: 6.428965091705322\n",
"epoch: 0, update in batch 136000/???, loss: 5.102448463439941\n",
"epoch: 0, update in batch 137000/???, loss: 6.239314556121826\n",
"epoch: 0, update in batch 138000/???, loss: 6.028595447540283\n",
"epoch: 0, update in batch 139000/???, loss: 6.407244682312012\n",
"epoch: 0, update in batch 140000/???, loss: 5.597055912017822\n",
"epoch: 0, update in batch 141000/???, loss: 5.823704719543457\n",
"epoch: 0, update in batch 142000/???, loss: 6.665535926818848\n",
"epoch: 0, update in batch 143000/???, loss: 5.5736894607543945\n",
"epoch: 0, update in batch 144000/???, loss: 6.723180294036865\n",
"epoch: 0, update in batch 145000/???, loss: 6.378345489501953\n",
"epoch: 0, update in batch 146000/???, loss: 5.6936845779418945\n",
"epoch: 0, update in batch 147000/???, loss: 5.761658668518066\n",
"epoch: 0, update in batch 148000/???, loss: 5.580254077911377\n",
"epoch: 0, update in batch 149000/???, loss: 5.733176231384277\n",
"epoch: 0, update in batch 150000/???, loss: 6.901691436767578\n",
"epoch: 0, update in batch 151000/???, loss: 6.5111589431762695\n",
"epoch: 0, update in batch 152000/???, loss: 6.184727668762207\n",
"epoch: 0, update in batch 153000/???, loss: 7.407107353210449\n",
"epoch: 0, update in batch 154000/???, loss: 6.499199867248535\n",
"epoch: 0, update in batch 155000/???, loss: 5.143393516540527\n",
"epoch: 0, update in batch 156000/???, loss: 7.60940408706665\n",
"epoch: 0, update in batch 157000/???, loss: 6.766045570373535\n",
"epoch: 0, update in batch 158000/???, loss: 5.268759727478027\n",
"epoch: 0, update in batch 159000/???, loss: 7.558129787445068\n",
"epoch: 0, update in batch 160000/???, loss: 8.016000747680664\n",
"epoch: 0, update in batch 161000/???, loss: 5.959166526794434\n",
"epoch: 0, update in batch 162000/???, loss: 5.499085426330566\n",
"epoch: 0, update in batch 163000/???, loss: 6.581662654876709\n",
"epoch: 0, update in batch 164000/???, loss: 6.681334495544434\n",
"epoch: 0, update in batch 165000/???, loss: 7.817207336425781\n",
"epoch: 0, update in batch 166000/???, loss: 6.524381160736084\n",
"epoch: 0, update in batch 167000/???, loss: 5.903864860534668\n",
"epoch: 0, update in batch 168000/???, loss: 5.6087260246276855\n",
"epoch: 0, update in batch 169000/???, loss: 5.742824554443359\n",
"epoch: 0, update in batch 170000/???, loss: 6.129671096801758\n",
"epoch: 0, update in batch 171000/???, loss: 5.879034519195557\n",
"epoch: 0, update in batch 172000/???, loss: 6.322129249572754\n",
"epoch: 0, update in batch 173000/???, loss: 6.805352210998535\n",
"epoch: 0, update in batch 174000/???, loss: 7.162431240081787\n",
"epoch: 0, update in batch 175000/???, loss: 6.123959541320801\n",
"epoch: 0, update in batch 176000/???, loss: 7.544029235839844\n",
"epoch: 0, update in batch 177000/???, loss: 5.4254021644592285\n",
"epoch: 0, update in batch 178000/???, loss: 5.784268379211426\n",
"epoch: 0, update in batch 179000/???, loss: 5.8633856773376465\n",
"epoch: 0, update in batch 180000/???, loss: 6.556314945220947\n",
"epoch: 0, update in batch 181000/???, loss: 5.215446472167969\n",
"epoch: 0, update in batch 182000/???, loss: 6.079234600067139\n",
"epoch: 0, update in batch 183000/???, loss: 7.234827995300293\n",
"epoch: 0, update in batch 184000/???, loss: 5.249889373779297\n",
"epoch: 0, update in batch 185000/???, loss: 5.083311080932617\n",
"epoch: 0, update in batch 186000/???, loss: 6.061867713928223\n",
"epoch: 0, update in batch 187000/???, loss: 6.060431480407715\n",
"epoch: 0, update in batch 188000/???, loss: 5.572680950164795\n",
"epoch: 0, update in batch 189000/???, loss: 5.991988182067871\n",
"epoch: 0, update in batch 190000/???, loss: 6.521245002746582\n",
"epoch: 0, update in batch 191000/???, loss: 5.128615379333496\n",
"epoch: 0, update in batch 192000/???, loss: 5.616750717163086\n",
"epoch: 0, update in batch 193000/???, loss: 6.1465044021606445\n",
"epoch: 0, update in batch 194000/???, loss: 5.93985652923584\n",
"epoch: 0, update in batch 195000/???, loss: 6.268892765045166\n",
"epoch: 0, update in batch 196000/???, loss: 5.928576469421387\n",
"epoch: 0, update in batch 197000/???, loss: 5.257290363311768\n",
"epoch: 0, update in batch 198000/???, loss: 6.6432952880859375\n",
"epoch: 0, update in batch 199000/???, loss: 6.898074150085449\n",
"epoch: 0, update in batch 200000/???, loss: 7.042447566986084\n",
"epoch: 0, update in batch 201000/???, loss: 7.104043483734131\n",
"epoch: 0, update in batch 202000/???, loss: 6.238812446594238\n",
"epoch: 0, update in batch 203000/???, loss: 6.773525238037109\n",
"epoch: 0, update in batch 204000/???, loss: 5.054592132568359\n",
"epoch: 0, update in batch 205000/???, loss: 6.854428768157959\n",
"epoch: 0, update in batch 206000/???, loss: 5.9983601570129395\n",
"epoch: 0, update in batch 207000/???, loss: 5.236695766448975\n",
"epoch: 0, update in batch 208000/???, loss: 6.086891174316406\n",
"epoch: 0, update in batch 209000/???, loss: 6.134495258331299\n",
"epoch: 0, update in batch 210000/???, loss: 6.52248477935791\n",
"epoch: 0, update in batch 211000/???, loss: 6.028376579284668\n",
"epoch: 0, update in batch 212000/???, loss: 6.140281677246094\n",
"epoch: 0, update in batch 213000/???, loss: 6.066422462463379\n",
"epoch: 0, update in batch 214000/???, loss: 6.868189334869385\n",
"epoch: 0, update in batch 215000/???, loss: 6.641358852386475\n",
"epoch: 0, update in batch 216000/???, loss: 6.818638801574707\n",
"epoch: 0, update in batch 217000/???, loss: 6.40252685546875\n",
"epoch: 0, update in batch 218000/???, loss: 5.561617851257324\n",
"epoch: 0, update in batch 219000/???, loss: 6.434267997741699\n",
"epoch: 0, update in batch 220000/???, loss: 6.33272123336792\n",
"epoch: 0, update in batch 221000/???, loss: 5.75616979598999\n",
"epoch: 0, update in batch 222000/???, loss: 6.477814674377441\n",
"epoch: 0, update in batch 223000/???, loss: 5.259497165679932\n",
"epoch: 0, update in batch 224000/???, loss: 5.8639655113220215\n",
"epoch: 0, update in batch 225000/???, loss: 6.469706058502197\n",
"epoch: 0, update in batch 226000/???, loss: 5.707249164581299\n",
"epoch: 0, update in batch 227000/???, loss: 6.394181251525879\n",
"epoch: 0, update in batch 228000/???, loss: 5.048886299133301\n",
"epoch: 0, update in batch 229000/???, loss: 5.842928409576416\n",
"epoch: 0, update in batch 230000/???, loss: 5.627688407897949\n",
"epoch: 0, update in batch 231000/???, loss: 7.950299263000488\n",
"epoch: 0, update in batch 232000/???, loss: 6.771368503570557\n",
"epoch: 0, update in batch 233000/???, loss: 5.787235260009766\n",
"epoch: 0, update in batch 234000/???, loss: 5.6070780754089355\n",
"epoch: 0, update in batch 235000/???, loss: 6.060035705566406\n",
"epoch: 0, update in batch 236000/???, loss: 6.894829750061035\n",
"epoch: 0, update in batch 237000/???, loss: 5.672856330871582\n",
"epoch: 0, update in batch 238000/???, loss: 5.054213523864746\n",
"epoch: 0, update in batch 239000/???, loss: 6.484643459320068\n",
"epoch: 0, update in batch 240000/???, loss: 5.800728797912598\n",
"epoch: 0, update in batch 241000/???, loss: 5.148013591766357\n",
"epoch: 0, update in batch 242000/???, loss: 5.529184818267822\n",
"epoch: 0, update in batch 243000/???, loss: 5.959448337554932\n",
"epoch: 0, update in batch 244000/???, loss: 6.762448787689209\n",
"epoch: 0, update in batch 245000/???, loss: 4.907589912414551\n",
"epoch: 0, update in batch 246000/???, loss: 6.275182723999023\n",
"epoch: 0, update in batch 247000/???, loss: 5.7234015464782715\n",
"epoch: 0, update in batch 248000/???, loss: 6.119207859039307\n",
"epoch: 0, update in batch 249000/???, loss: 5.297057151794434\n",
"epoch: 0, update in batch 250000/???, loss: 5.924614906311035\n",
"epoch: 0, update in batch 251000/???, loss: 6.651083469390869\n",
"epoch: 0, update in batch 252000/???, loss: 5.7164201736450195\n",
"epoch: 0, update in batch 253000/???, loss: 6.105191230773926\n",
"epoch: 0, update in batch 254000/???, loss: 5.791018486022949\n",
"epoch: 0, update in batch 255000/???, loss: 6.659502983093262\n",
"epoch: 0, update in batch 256000/???, loss: 5.613073348999023\n",
"epoch: 0, update in batch 257000/???, loss: 7.501049041748047\n",
"epoch: 0, update in batch 258000/???, loss: 6.043797492980957\n",
"epoch: 0, update in batch 259000/???, loss: 7.3587327003479\n",
"epoch: 0, update in batch 260000/???, loss: 6.276612281799316\n",
"epoch: 0, update in batch 261000/???, loss: 6.445192813873291\n",
"epoch: 0, update in batch 262000/???, loss: 5.0266547203063965\n",
"epoch: 0, update in batch 263000/???, loss: 6.404935359954834\n",
"epoch: 0, update in batch 264000/???, loss: 6.5042290687561035\n",
"epoch: 0, update in batch 265000/???, loss: 6.880773067474365\n",
"epoch: 0, update in batch 266000/???, loss: 6.3690643310546875\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"epoch: 0, update in batch 267000/???, loss: 6.055562973022461\n",
"epoch: 0, update in batch 268000/???, loss: 5.796906471252441\n",
"epoch: 0, update in batch 269000/???, loss: 5.654962539672852\n",
"epoch: 0, update in batch 270000/???, loss: 6.574362277984619\n",
"epoch: 0, update in batch 271000/???, loss: 6.256768226623535\n",
"epoch: 0, update in batch 272000/???, loss: 6.8345208168029785\n",
"epoch: 0, update in batch 273000/???, loss: 6.066469669342041\n",
"epoch: 0, update in batch 274000/???, loss: 6.625809669494629\n",
"epoch: 0, update in batch 275000/???, loss: 4.762896537780762\n",
"epoch: 0, update in batch 276000/???, loss: 6.019833564758301\n",
"epoch: 0, update in batch 277000/???, loss: 6.227939605712891\n",
"epoch: 0, update in batch 278000/???, loss: 7.046879768371582\n",
"epoch: 0, update in batch 279000/???, loss: 6.068551540374756\n",
"epoch: 0, update in batch 280000/???, loss: 6.454771995544434\n",
"epoch: 0, update in batch 281000/???, loss: 3.9379985332489014\n",
"epoch: 0, update in batch 282000/???, loss: 5.615240097045898\n",
"epoch: 0, update in batch 283000/???, loss: 5.7963151931762695\n",
"epoch: 0, update in batch 284000/???, loss: 6.064437389373779\n",
"epoch: 0, update in batch 285000/???, loss: 6.668734073638916\n",
"epoch: 0, update in batch 286000/???, loss: 6.776829719543457\n",
"epoch: 0, update in batch 287000/???, loss: 6.170516014099121\n",
"epoch: 0, update in batch 288000/???, loss: 4.840399742126465\n",
"epoch: 0, update in batch 289000/???, loss: 6.333052635192871\n",
"epoch: 0, update in batch 290000/???, loss: 5.595047950744629\n",
"epoch: 0, update in batch 291000/???, loss: 6.594934940338135\n",
"epoch: 0, update in batch 292000/???, loss: 5.950274467468262\n",
"epoch: 0, update in batch 293000/???, loss: 6.123660087585449\n",
"epoch: 0, update in batch 294000/???, loss: 5.904355049133301\n",
"epoch: 0, update in batch 295000/???, loss: 5.8828630447387695\n",
"epoch: 0, update in batch 296000/???, loss: 5.604973316192627\n",
"epoch: 0, update in batch 297000/???, loss: 4.842469692230225\n",
"epoch: 0, update in batch 298000/???, loss: 5.862446308135986\n",
"epoch: 0, update in batch 299000/???, loss: 6.90258264541626\n",
"epoch: 0, update in batch 300000/???, loss: 5.941957950592041\n",
"epoch: 0, update in batch 301000/???, loss: 5.697750568389893\n",
"epoch: 0, update in batch 302000/???, loss: 5.973014831542969\n",
"epoch: 0, update in batch 303000/???, loss: 5.46022367477417\n",
"epoch: 0, update in batch 304000/???, loss: 6.5218095779418945\n",
"epoch: 0, update in batch 305000/???, loss: 6.392545700073242\n",
"epoch: 0, update in batch 306000/???, loss: 7.080249786376953\n",
"epoch: 0, update in batch 307000/???, loss: 6.355096817016602\n",
"epoch: 0, update in batch 308000/???, loss: 5.625491619110107\n",
"epoch: 0, update in batch 309000/???, loss: 6.805799961090088\n",
"epoch: 0, update in batch 310000/???, loss: 6.426385402679443\n",
"epoch: 0, update in batch 311000/???, loss: 5.727842807769775\n",
"epoch: 0, update in batch 312000/???, loss: 6.9111199378967285\n",
"epoch: 0, update in batch 313000/???, loss: 6.40056848526001\n",
"epoch: 0, update in batch 314000/???, loss: 6.145076751708984\n",
"epoch: 0, update in batch 315000/???, loss: 6.097104072570801\n",
"epoch: 0, update in batch 316000/???, loss: 5.39146089553833\n",
"epoch: 0, update in batch 317000/???, loss: 6.125569820404053\n",
"epoch: 0, update in batch 318000/???, loss: 6.533677577972412\n",
"epoch: 0, update in batch 319000/???, loss: 5.944211483001709\n",
"epoch: 0, update in batch 320000/???, loss: 6.542410850524902\n",
"epoch: 0, update in batch 321000/???, loss: 5.699315071105957\n",
"epoch: 0, update in batch 322000/???, loss: 6.251957893371582\n",
"epoch: 0, update in batch 323000/???, loss: 5.346350193023682\n",
"epoch: 0, update in batch 324000/???, loss: 5.603858470916748\n",
"epoch: 0, update in batch 325000/???, loss: 5.740134239196777\n",
"epoch: 0, update in batch 326000/???, loss: 5.575300693511963\n",
"epoch: 0, update in batch 327000/???, loss: 6.996762752532959\n",
"epoch: 0, update in batch 328000/???, loss: 6.28995418548584\n",
"epoch: 0, update in batch 329000/???, loss: 4.519123077392578\n",
"epoch: 0, update in batch 330000/???, loss: 5.9068121910095215\n",
"epoch: 0, update in batch 331000/???, loss: 6.61830997467041\n",
"epoch: 0, update in batch 332000/???, loss: 6.063097953796387\n",
"epoch: 0, update in batch 333000/???, loss: 6.419328212738037\n",
"epoch: 0, update in batch 334000/???, loss: 5.927584648132324\n",
"epoch: 0, update in batch 335000/???, loss: 5.527887344360352\n",
"epoch: 0, update in batch 336000/???, loss: 6.114096641540527\n",
"epoch: 0, update in batch 337000/???, loss: 5.9415082931518555\n",
"epoch: 0, update in batch 338000/???, loss: 5.288441181182861\n",
"epoch: 0, update in batch 339000/???, loss: 6.611715793609619\n",
"epoch: 0, update in batch 340000/???, loss: 6.770573616027832\n"
]
}
],
"source": [
"train(train_dataset_back, model_back, 1, 64)"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "36a1b802",
"metadata": {},
"outputs": [],
"source": [
"def predict_probs(left_tokens, right_tokens):\n",
"    # Predict the gap word by combining the forward LM (reads the left\n",
"    # context) with the backward LM (reads the reversed right context),\n",
"    # then formatting the 30 best candidates as 'word:prob' pairs.\n",
"    model_front.eval()\n",
"    model_back.eval()\n",
"\n",
"    # Map tokens to vocabulary ids; out-of-vocabulary words become <unk>.\n",
"    # Bug fix: the membership test previously checked a bare `key_to_index`\n",
"    # name while the lookup used `train_dataset_front.key_to_index`; both\n",
"    # must use the same mapping.\n",
"    vocab = train_dataset_front.key_to_index\n",
"    x_left = torch.tensor([[vocab[w] if w in vocab else vocab['<unk>'] for w in left_tokens]]).to(device)\n",
"    x_right = torch.tensor([[vocab[w] if w in vocab else vocab['<unk>'] for w in right_tokens]]).to(device)\n",
"    y_pred_left, _ = model_front(x_left)\n",
"    y_pred_right, _ = model_back(x_right)\n",
"\n",
"    # Next-word distribution after the last position of each context.\n",
"    last_word_logits_left = y_pred_left[0][-1]\n",
"    last_word_logits_right = y_pred_right[0][-1]\n",
"    probs_left = torch.nn.functional.softmax(last_word_logits_left, dim=0).detach().cpu().numpy()\n",
"    probs_right = torch.nn.functional.softmax(last_word_logits_right, dim=0).detach().cpu().numpy()\n",
"\n",
"    # Average the two distributions (vectorized; replaces a per-element\n",
"    # Python loop over the whole vocabulary).\n",
"    probs = (probs_left + probs_right) / 2\n",
"\n",
"    # Top 30 candidates; the final vocabulary slot is skipped, matching the\n",
"    # original selection loop, which excluded the last index from\n",
"    # consideration.\n",
"    candidates = np.argsort(probs[:-1])[-30:][::-1]\n",
"\n",
"    prediction = ''\n",
"    sum_prob = 0.0\n",
"    for idx in candidates:\n",
"        word_prob = probs[idx]\n",
"        sum_prob += word_prob\n",
"        prediction += f'{index_to_key[idx]}:{word_prob} '\n",
"    # Remaining probability mass goes to the catch-all empty key.\n",
"    prediction += f':{1 - sum_prob}'\n",
"\n",
"    return prediction"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "155636b5",
"metadata": {},
"outputs": [],
"source": [
"# Inputs for the dev and test splits: tab-separated, headerless; quoting is\n",
"# disabled because the raw text contains unbalanced quote characters.\n",
"dev_data = pd.read_csv(\n",
"    'dev-0/in.tsv.xz', sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE\n",
")\n",
"test_data = pd.read_csv(\n",
"    'test-A/in.tsv.xz', sep='\\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "99b1d944",
"metadata": {},
"outputs": [],
"source": [
"# Write dev-0 predictions: 5 context tokens from each side of the gap feed\n",
"# the two language models; rows with too little context on either side get\n",
"# the uniform fallback distribution ':1.0'.\n",
"with open('dev-0/out.tsv', 'w') as file:\n",
"    for _, row in dev_data.iterrows():\n",
"        left_words = word_tokenize(clean_text(str(row[6])))\n",
"        right_words = word_tokenize(clean_text(str(row[7])))\n",
"        # The backward model was trained on reversed text, so reverse the\n",
"        # right context before slicing off its 5 gap-adjacent tokens.\n",
"        right_words.reverse()\n",
"        if len(left_words) < 6 or len(right_words) < 6:\n",
"            prediction = ':1.0'\n",
"        else:\n",
"            prediction = predict_probs(left_words[-5:], right_words[-5:])\n",
"        file.write(prediction + '\\n')"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "186c3269",
"metadata": {},
"outputs": [],
"source": [
"# Write test-A predictions: 5 context tokens from each side of the gap feed\n",
"# the two language models; rows with too little context on either side get\n",
"# the uniform fallback distribution ':1.0'.\n",
"with open('test-A/out.tsv', 'w') as file:\n",
"    for _, row in test_data.iterrows():\n",
"        left_words = word_tokenize(clean_text(str(row[6])))\n",
"        right_words = word_tokenize(clean_text(str(row[7])))\n",
"        # The backward model was trained on reversed text, so reverse the\n",
"        # right context before slicing off its 5 gap-adjacent tokens.\n",
"        right_words.reverse()\n",
"        if len(left_words) < 6 or len(right_words) < 6:\n",
"            prediction = ':1.0'\n",
"        else:\n",
"            prediction = predict_probs(left_words[-5:], right_words[-5:])\n",
"        file.write(prediction + '\\n')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.2"
}
},
"nbformat": 4,
"nbformat_minor": 5
}