62 KiB
62 KiB
import pandas as pd
import numpy as np
import regex as re
import csv
import torch
from torch import nn
from gensim.models import Word2Vec
from nltk.tokenize import word_tokenize
# Release any cached GPU memory held by the allocator, then choose the device
# that the rest of the notebook moves models and batches to.
torch.cuda.empty_cache()
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def clean_text(text):
    """Normalise a raw text fragment before tokenisation.

    Steps, in order:
      1. lowercase the text;
      2. delete hyphenated escaped line breaks and turn the remaining escaped
         line breaks into spaces;
      3. expand the contractions 't, 's, 'll, 'm and 've (e.g. "it's" ->
         "it is") when they directly follow a word character and end a word;
      4. remove every character in the Unicode punctuation category P.

    Contractions are now expanded *before* punctuation is removed. The
    punctuation pattern also matches the apostrophe, so when expansion ran
    afterwards (as it used to) it could never match anything.

    Args:
        text: the string to clean.

    Returns:
        The cleaned string.
    """
    expansions = {'t': ' not', 's': ' is', 'll': ' will', 'm': ' am', 've': ' have'}
    # NOTE(review): at runtime these literals are "-" + 8 backslashes + "n"
    # and 8 backslashes + "n"; that many backslashes looks like repeated
    # notebook-export escaping -- confirm they match how line breaks are
    # escaped in the raw data.
    text = text.lower().replace('-\\\\\\\\\\\\\\\\n', '').replace('\\\\\\\\\\\\\\\\n', ' ')
    # The look-behind and the trailing word boundary stop single quotation
    # marks from being read as contractions ("'the" must not become " nothe").
    # The replacement mapping is the original one, so "don't" -> "don not".
    text = re.sub(r"(?<=\w)'(t|s|ll|m|ve)\b", lambda m: expansions[m.group(1)], text)
    text = re.sub(r'\p{P}', '', text)
    return text
# Load the training inputs and the expected gap words. Both files are read with
# on_bad_lines='skip', and the two frames are then joined side by side by row
# position. A malformed line dropped from only one file would therefore shift
# every later row out of alignment with its label. A length check cannot prove
# alignment, but it catches the common case of lines lost from one file only.
train_data = pd.read_csv('train/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
train_labels = pd.read_csv('train/expected.tsv', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
if len(train_data) != len(train_labels):
    raise ValueError(
        f'train/in.tsv.xz has {len(train_data)} usable rows but train/expected.tsv has '
        f'{len(train_labels)}; inputs and labels would be misaligned (lines were skipped)'
    )
# Keep columns 6 and 7 (presumably the text before / after the gap -- confirm
# against the dataset description) and append the label column, which keeps
# its column name 0.
train_data = train_data[[6, 7]]
train_data = pd.concat([train_data, train_labels], axis=1)
class TrainCorpus:
    """Re-iterable corpus of tokenised training texts (used for Word2Vec vocab building).

    Each row of ``data`` contributes one token list, built as follows:
      * take columns 6, 0 and 7 (presumably text before the gap, gap word,
        text after the gap -- confirm against the dataset spec);
      * join them with single spaces;
      * clean the result with ``clean_text``;
      * tokenise it with ``word_tokenize``.

    It is a class with ``__iter__`` rather than a generator, so it can be
    iterated more than once.

    Args:
        data: DataFrame containing columns 6, 0 and 7.
    """

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        # zip over the columns avoids building a Series per row (iterrows).
        for left, gap, right in zip(self.data[6], self.data[0], self.data[7]):
            # Join with spaces: plain concatenation glued the gap word onto the
            # last word before it and the first word after it.
            text = ' '.join((str(left), str(gap), str(right)))
            yield word_tokenize(clean_text(text))
# Build a vocabulary of words seen at least 10 times in the first 80 000 rows.
# Only build_vocab is called in this section, so Word2Vec is used here just to
# count words; its vectors are not trained in this code.
train_sentences = TrainCorpus(train_data.head(80000))
w2v_model = Word2Vec(vector_size=100, min_count=10)
w2v_model.build_vocab(corpus_iterable=train_sentences)
# Copy gensim's lookup tables before adding '<unk>'. wv.index_to_key and
# wv.key_to_index are the model's own list/dict, so extending them in place
# would leave w2v_model.wv with one more key than it has vector rows.
key_to_index = dict(w2v_model.wv.key_to_index)
index_to_key = list(w2v_model.wv.index_to_key)
index_to_key.append('<unk>')
key_to_index['<unk>'] = len(index_to_key) - 1
vocab_size = len(index_to_key)
print(vocab_size)
81477
class TrainDataset(torch.utils.data.IterableDataset):
    """Stream (input, target) token-index windows for next-word language modelling.

    Text construction, for every row of ``data``:
      * join columns 6, 0 and 7 with spaces (presumably text before the gap,
        the gap word and text after it -- confirm against the dataset spec);
      * clean the result with ``clean_text``;
      * tokenise it with ``word_tokenize``;
      * reverse the tokens if ``reversed`` is set.

    Every window of ``context_size + 1`` consecutive tokens then yields:
      * ``input``: the first ``context_size`` token indices;
      * ``target``: the same window shifted one token to the right.

    Both are 1-D ``int64`` numpy arrays. Words missing from ``key_to_index``
    map to the index of ``'<unk>'``, which must be present in the mapping.

    NOTE: there is no per-worker sharding, so with ``DataLoader(num_workers>0)``
    every worker would yield the full stream.

    Args:
        data: DataFrame with columns 6, 0 and 7.
        index_to_key: index -> word list (stored; not used while iterating).
        key_to_index: word -> index mapping that contains ``'<unk>'``.
        reversed: if True, reverse the token order before windowing
            (right-to-left model). The name is kept for backward compatibility.
        context_size: number of tokens per input window (default 5, as before).
    """

    def __init__(self, data, index_to_key, key_to_index, reversed=False, context_size=5):
        self.reversed = reversed
        self.data = data
        self.index_to_key = index_to_key
        self.key_to_index = key_to_index
        self.vocab_size = len(key_to_index)
        self.context_size = context_size

    def __iter__(self):
        unk = self.key_to_index['<unk>']
        lookup = self.key_to_index.get
        n = self.context_size
        for left, gap, right in zip(self.data[6], self.data[0], self.data[7]):
            # Space-join so the gap word does not fuse with its neighbours.
            text = clean_text(' '.join((str(left), str(gap), str(right))))
            tokens = word_tokenize(text)
            if self.reversed:
                tokens = tokens[::-1]
            # Map every token once, instead of once per window it appears in.
            ids = np.asarray([lookup(tok, unk) for tok in tokens], dtype=np.int64)
            for i in range(n, len(ids)):
                yield ids[i - n:i].copy(), ids[i - n + 1:i + 1].copy()
class Model(nn.Module):
    """Word-level LSTM language model: embedding -> 2-layer LSTM -> vocabulary logits.

    ``forward`` is batch-first:
      * ``x``: integer tensor of shape ``(batch, seq_len)``;
      * returned logits: shape ``(batch, seq_len, vocab_size)``.

    Args:
        embed_size: dimensionality of the learned token embeddings.
        vocab_size: number of token ids (embedding rows and output width).
    """

    def __init__(self, embed_size, vocab_size):
        super().__init__()
        self.embed_size = embed_size
        self.vocab_size = vocab_size
        self.lstm_size = 128
        self.num_layers = 2
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=self.embed_size)
        # batch_first=True: callers pass (batch, seq_len) tensors. With the
        # default (seq_len, batch, ...) layout the LSTM treated the batch axis
        # as time and carried hidden state from one training example into the next.
        self.lstm = nn.LSTM(input_size=self.embed_size, hidden_size=self.lstm_size,
                            num_layers=self.num_layers, dropout=0.2, batch_first=True)
        self.fc = nn.Linear(self.lstm_size, vocab_size)

    def forward(self, x, prev_state = None):
        """Return ``(logits, (h_n, c_n))`` for token ids ``x``.

        The logits are unnormalised; apply softmax over the last dimension to
        get probabilities. ``prev_state`` is an optional ``(h_0, c_0)`` pair,
        each of shape ``(num_layers, batch, lstm_size)``. ``None`` means a zero
        initial state.
        """
        embed = self.embed(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return a zero ``(h_0, c_0)`` pair suitable for ``forward``.

        Args:
            sequence_length: number of sequences in the batch, i.e. the batch
                size (the middle dimension of the LSTM state tensors).
        """
        # Allocate on the model's own device rather than a global setting.
        zeros = torch.zeros(self.num_layers, sequence_length, self.lstm_size,
                            device=self.fc.weight.device)
        return (zeros, zeros)
from torch.utils.data import DataLoader
from torch.optim import Adam
def train(dataset, model, max_epochs, batch_size, lr=0.001, log_every=1000):
    """Train ``model`` with Adam and cross-entropy loss.

    ``dataset`` must yield ``(x, y)`` pairs of token-id sequences, which the
    DataLoader batches into ``(batch, seq_len)`` tensors. ``model(x)`` must
    return ``(logits, state)`` with logits of shape ``(batch, seq_len, vocab)``.
    Each batch is moved to the device of the model's first parameter.

    Args:
        dataset: map-style or iterable dataset of ``(x, y)`` pairs.
        model: network to train; it is switched to training mode.
        max_epochs: number of passes over ``dataset``.
        batch_size: samples per optimisation step.
        lr: Adam learning rate (default 0.001, as before).
        log_every: print the current batch loss every this many batches.
    """
    model.train()
    target_device = next(model.parameters()).device
    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=lr)
    for epoch in range(max_epochs):
        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            x = x.to(target_device)
            y = y.to(target_device)
            y_pred, _ = model(x)
            # CrossEntropyLoss wants the class axis second: (batch, vocab, seq_len).
            loss = criterion(y_pred.transpose(1, 2), y)
            loss.backward()
            optimizer.step()
            if batch % log_every == 0:
                # The total batch count is unknown for an IterableDataset, hence "???".
                print(f'epoch: {epoch}, update in batch {batch}/???, loss: {loss.item()}')
# Left-to-right stream over the first 80 000 rows.
train_dataset_front = TrainDataset(train_data.head(80000), index_to_key, key_to_index, reversed=False)
# Right-to-left stream over the last 80 000 rows.
# NOTE(review): the vocabulary is built from head(80000), so tail() rows will
# likely contain more '<unk>' tokens -- confirm tail() is intended here.
train_dataset_back = TrainDataset(train_data.tail(80000), index_to_key, key_to_index, reversed=True)
# Two identically shaped models with 100-dimensional embeddings.
model_front = Model(embed_size=100, vocab_size=vocab_size).to(device)
model_back = Model(embed_size=100, vocab_size=vocab_size).to(device)
train(dataset=train_dataset_front, model=model_front, max_epochs=1, batch_size=64)
epoch: 0, update in batch 0/???, loss: 11.314821243286133 epoch: 0, update in batch 1000/???, loss: 6.876476287841797 epoch: 0, update in batch 2000/???, loss: 7.133523464202881 epoch: 0, update in batch 3000/???, loss: 6.979971885681152 epoch: 0, update in batch 4000/???, loss: 7.018368721008301 epoch: 0, update in batch 5000/???, loss: 6.494096279144287 epoch: 0, update in batch 6000/???, loss: 6.448479652404785 epoch: 0, update in batch 7000/???, loss: 6.526387691497803 epoch: 0, update in batch 8000/???, loss: 6.536323547363281 epoch: 0, update in batch 9000/???, loss: 6.4919538497924805 epoch: 0, update in batch 10000/???, loss: 6.435188293457031 epoch: 0, update in batch 11000/???, loss: 6.934823513031006 epoch: 0, update in batch 12000/???, loss: 7.410381317138672 epoch: 0, update in batch 13000/???, loss: 8.227864265441895 epoch: 0, update in batch 14000/???, loss: 6.7139105796813965 epoch: 0, update in batch 15000/???, loss: 6.82781457901001 epoch: 0, update in batch 16000/???, loss: 6.637822151184082 epoch: 0, update in batch 17000/???, loss: 6.2633233070373535 epoch: 0, update in batch 18000/???, loss: 6.512040138244629 epoch: 0, update in batch 19000/???, loss: 5.745478630065918 epoch: 0, update in batch 20000/???, loss: 7.039064884185791 epoch: 0, update in batch 21000/???, loss: 7.151158332824707 epoch: 0, update in batch 22000/???, loss: 6.460148811340332 epoch: 0, update in batch 23000/???, loss: 7.396632194519043 epoch: 0, update in batch 24000/???, loss: 5.907363414764404 epoch: 0, update in batch 25000/???, loss: 6.669890403747559 epoch: 0, update in batch 26000/???, loss: 6.032290458679199 epoch: 0, update in batch 27000/???, loss: 6.192468166351318 epoch: 0, update in batch 28000/???, loss: 5.757508277893066 epoch: 0, update in batch 29000/???, loss: 7.097552299499512 epoch: 0, update in batch 30000/???, loss: 6.8356804847717285 epoch: 0, update in batch 31000/???, loss: 4.938998699188232 epoch: 0, update in batch 32000/???, loss: 
6.34550142288208 epoch: 0, update in batch 33000/???, loss: 7.154759883880615 epoch: 0, update in batch 34000/???, loss: 6.8563055992126465 epoch: 0, update in batch 35000/???, loss: 6.831148624420166 epoch: 0, update in batch 36000/???, loss: 6.867754936218262 epoch: 0, update in batch 37000/???, loss: 6.911463260650635 epoch: 0, update in batch 38000/???, loss: 6.637528896331787 epoch: 0, update in batch 39000/???, loss: 6.822340488433838 epoch: 0, update in batch 40000/???, loss: 6.122499942779541 epoch: 0, update in batch 41000/???, loss: 6.454296112060547 epoch: 0, update in batch 42000/???, loss: 7.5895185470581055 epoch: 0, update in batch 43000/???, loss: 5.775805473327637 epoch: 0, update in batch 44000/???, loss: 5.973118305206299 epoch: 0, update in batch 45000/???, loss: 5.7727460861206055 epoch: 0, update in batch 46000/???, loss: 6.376847267150879 epoch: 0, update in batch 47000/???, loss: 5.739894866943359 epoch: 0, update in batch 48000/???, loss: 6.390743732452393 epoch: 0, update in batch 49000/???, loss: 7.724233150482178 epoch: 0, update in batch 50000/???, loss: 5.242608070373535 epoch: 0, update in batch 51000/???, loss: 5.412053108215332 epoch: 0, update in batch 52000/???, loss: 6.590373992919922 epoch: 0, update in batch 53000/???, loss: 6.46323299407959 epoch: 0, update in batch 54000/???, loss: 6.9850263595581055 epoch: 0, update in batch 55000/???, loss: 7.3167219161987305 epoch: 0, update in batch 56000/???, loss: 6.285423278808594 epoch: 0, update in batch 57000/???, loss: 7.417998313903809 epoch: 0, update in batch 58000/???, loss: 6.437861442565918 epoch: 0, update in batch 59000/???, loss: 6.522177219390869 epoch: 0, update in batch 60000/???, loss: 5.9156928062438965 epoch: 0, update in batch 61000/???, loss: 4.946429252624512 epoch: 0, update in batch 62000/???, loss: 6.633675575256348 epoch: 0, update in batch 63000/???, loss: 7.357038974761963 epoch: 0, update in batch 64000/???, loss: 5.774768352508545 epoch: 0, update in batch 
65000/???, loss: 6.289044380187988 epoch: 0, update in batch 66000/???, loss: 6.127488136291504 epoch: 0, update in batch 67000/???, loss: 5.059685230255127 epoch: 0, update in batch 68000/???, loss: 6.5439910888671875 epoch: 0, update in batch 69000/???, loss: 6.679286956787109 epoch: 0, update in batch 70000/???, loss: 7.2232346534729 epoch: 0, update in batch 71000/???, loss: 6.13685941696167 epoch: 0, update in batch 72000/???, loss: 5.766592025756836 epoch: 0, update in batch 73000/???, loss: 6.772070407867432 epoch: 0, update in batch 74000/???, loss: 7.369122505187988 epoch: 0, update in batch 75000/???, loss: 6.598935127258301 epoch: 0, update in batch 76000/???, loss: 5.948511600494385 epoch: 0, update in batch 77000/???, loss: 6.507765769958496 epoch: 0, update in batch 78000/???, loss: 5.09373664855957 epoch: 0, update in batch 79000/???, loss: 5.9862494468688965 epoch: 0, update in batch 80000/???, loss: 6.106108665466309 epoch: 0, update in batch 81000/???, loss: 5.2747578620910645 epoch: 0, update in batch 82000/???, loss: 6.324326515197754 epoch: 0, update in batch 83000/???, loss: 5.914392471313477 epoch: 0, update in batch 84000/???, loss: 6.641409873962402 epoch: 0, update in batch 85000/???, loss: 6.287321090698242 epoch: 0, update in batch 86000/???, loss: 6.510883331298828 epoch: 0, update in batch 87000/???, loss: 6.458550930023193 epoch: 0, update in batch 88000/???, loss: 6.07730770111084 epoch: 0, update in batch 89000/???, loss: 6.2387471199035645 epoch: 0, update in batch 90000/???, loss: 5.63344669342041 epoch: 0, update in batch 91000/???, loss: 6.277956962585449 epoch: 0, update in batch 92000/???, loss: 6.841054439544678 epoch: 0, update in batch 93000/???, loss: 6.458809852600098 epoch: 0, update in batch 94000/???, loss: 7.471741676330566 epoch: 0, update in batch 95000/???, loss: 6.461136817932129 epoch: 0, update in batch 96000/???, loss: 5.718675136566162 epoch: 0, update in batch 97000/???, loss: 4.4265007972717285 epoch: 0, 
update in batch 98000/???, loss: 7.05142879486084 epoch: 0, update in batch 99000/???, loss: 6.341854572296143 epoch: 0, update in batch 100000/???, loss: 6.834918022155762 epoch: 0, update in batch 101000/???, loss: 5.367598056793213 epoch: 0, update in batch 102000/???, loss: 5.716221809387207 epoch: 0, update in batch 103000/???, loss: 6.9465742111206055 epoch: 0, update in batch 104000/???, loss: 5.976019382476807 epoch: 0, update in batch 105000/???, loss: 6.125661849975586 epoch: 0, update in batch 106000/???, loss: 6.724229335784912 epoch: 0, update in batch 107000/???, loss: 6.446004390716553 epoch: 0, update in batch 108000/???, loss: 6.4710845947265625 epoch: 0, update in batch 109000/???, loss: 6.5926103591918945 epoch: 0, update in batch 110000/???, loss: 6.966839790344238 epoch: 0, update in batch 111000/???, loss: 7.263918876647949 epoch: 0, update in batch 112000/???, loss: 6.7561750411987305 epoch: 0, update in batch 113000/???, loss: 6.142555236816406 epoch: 0, update in batch 114000/???, loss: 5.974082946777344 epoch: 0, update in batch 115000/???, loss: 5.565796852111816 epoch: 0, update in batch 116000/???, loss: 6.4826202392578125 epoch: 0, update in batch 117000/???, loss: 5.643266201019287 epoch: 0, update in batch 118000/???, loss: 6.360909461975098 epoch: 0, update in batch 119000/???, loss: 5.4074201583862305 epoch: 0, update in batch 120000/???, loss: 7.1339569091796875 epoch: 0, update in batch 121000/???, loss: 6.786561012268066 epoch: 0, update in batch 122000/???, loss: 6.329574108123779 epoch: 0, update in batch 123000/???, loss: 7.21968936920166 epoch: 0, update in batch 124000/???, loss: 5.351359844207764 epoch: 0, update in batch 125000/???, loss: 7.962380886077881 epoch: 0, update in batch 126000/???, loss: 6.351782321929932 epoch: 0, update in batch 127000/???, loss: 6.8343048095703125 epoch: 0, update in batch 128000/???, loss: 6.129800319671631 epoch: 0, update in batch 129000/???, loss: 6.68627405166626 epoch: 0, update in 
batch 130000/???, loss: 6.498664855957031 epoch: 0, update in batch 131000/???, loss: 5.724549293518066 epoch: 0, update in batch 132000/???, loss: 7.041095733642578 epoch: 0, update in batch 133000/???, loss: 5.901988983154297 epoch: 0, update in batch 134000/???, loss: 6.055495262145996 epoch: 0, update in batch 135000/???, loss: 6.363399982452393 epoch: 0, update in batch 136000/???, loss: 7.45733642578125 epoch: 0, update in batch 137000/???, loss: 6.960203647613525 epoch: 0, update in batch 138000/???, loss: 6.986503601074219 epoch: 0, update in batch 139000/???, loss: 5.7938127517700195 epoch: 0, update in batch 140000/???, loss: 5.559916019439697 epoch: 0, update in batch 141000/???, loss: 5.551616668701172 epoch: 0, update in batch 142000/???, loss: 5.386819839477539 epoch: 0, update in batch 143000/???, loss: 6.826618194580078 epoch: 0, update in batch 144000/???, loss: 6.106345176696777 epoch: 0, update in batch 145000/???, loss: 6.812024116516113 epoch: 0, update in batch 146000/???, loss: 6.347486972808838 epoch: 0, update in batch 147000/???, loss: 6.20189094543457 epoch: 0, update in batch 148000/???, loss: 5.5717034339904785 epoch: 0, update in batch 149000/???, loss: 6.884232521057129 epoch: 0, update in batch 150000/???, loss: 6.8074846267700195 epoch: 0, update in batch 151000/???, loss: 7.028794288635254 epoch: 0, update in batch 152000/???, loss: 5.201214790344238 epoch: 0, update in batch 153000/???, loss: 5.1864013671875 epoch: 0, update in batch 154000/???, loss: 6.4473114013671875 epoch: 0, update in batch 155000/???, loss: 4.9203643798828125 epoch: 0, update in batch 156000/???, loss: 6.829309940338135 epoch: 0, update in batch 157000/???, loss: 7.045801639556885 epoch: 0, update in batch 158000/???, loss: 6.4073967933654785 epoch: 0, update in batch 159000/???, loss: 6.494145393371582 epoch: 0, update in batch 160000/???, loss: 6.682474613189697 epoch: 0, update in batch 161000/???, loss: 5.125617980957031 epoch: 0, update in batch 
162000/???, loss: 5.915367126464844 epoch: 0, update in batch 163000/???, loss: 6.4779157638549805 epoch: 0, update in batch 164000/???, loss: 5.547584533691406 epoch: 0, update in batch 165000/???, loss: 6.134579181671143 epoch: 0, update in batch 166000/???, loss: 5.300144672393799 epoch: 0, update in batch 167000/???, loss: 6.53488826751709 epoch: 0, update in batch 168000/???, loss: 6.711917877197266 epoch: 0, update in batch 169000/???, loss: 7.0150322914123535 epoch: 0, update in batch 170000/???, loss: 5.681846618652344 epoch: 0, update in batch 171000/???, loss: 6.583130836486816 epoch: 0, update in batch 172000/???, loss: 6.411820411682129 epoch: 0, update in batch 173000/???, loss: 5.725490093231201 epoch: 0, update in batch 174000/???, loss: 6.651374816894531 epoch: 0, update in batch 175000/???, loss: 5.800152778625488 epoch: 0, update in batch 176000/???, loss: 6.862998962402344 epoch: 0, update in batch 177000/???, loss: 6.668658256530762 epoch: 0, update in batch 178000/???, loss: 6.519270896911621 epoch: 0, update in batch 179000/???, loss: 6.716788291931152 epoch: 0, update in batch 180000/???, loss: 6.675846099853516 epoch: 0, update in batch 181000/???, loss: 6.598060607910156 epoch: 0, update in batch 182000/???, loss: 6.638599395751953 epoch: 0, update in batch 183000/???, loss: 5.693145275115967 epoch: 0, update in batch 184000/???, loss: 5.175653457641602 epoch: 0, update in batch 185000/???, loss: 6.659600734710693 epoch: 0, update in batch 186000/???, loss: 5.782421112060547 epoch: 0, update in batch 187000/???, loss: 6.1736297607421875 epoch: 0, update in batch 188000/???, loss: 5.38541316986084 epoch: 0, update in batch 189000/???, loss: 6.238187789916992 epoch: 0, update in batch 190000/???, loss: 6.10030460357666 epoch: 0, update in batch 191000/???, loss: 6.680960655212402 epoch: 0, update in batch 192000/???, loss: 6.600944519042969 epoch: 0, update in batch 193000/???, loss: 6.171700477600098 epoch: 0, update in batch 194000/???, 
loss: 7.250021934509277 epoch: 0, update in batch 195000/???, loss: 5.968771934509277 epoch: 0, update in batch 196000/???, loss: 7.107605934143066 epoch: 0, update in batch 197000/???, loss: 6.743283748626709 epoch: 0, update in batch 198000/???, loss: 7.130635738372803 epoch: 0, update in batch 199000/???, loss: 6.37470817565918 epoch: 0, update in batch 200000/???, loss: 6.050590515136719 epoch: 0, update in batch 201000/???, loss: 5.468177318572998 epoch: 0, update in batch 202000/???, loss: 6.343471527099609 epoch: 0, update in batch 203000/???, loss: 6.890538692474365 epoch: 0, update in batch 204000/???, loss: 7.018721580505371 epoch: 0, update in batch 205000/???, loss: 6.131939888000488 epoch: 0, update in batch 206000/???, loss: 6.219918251037598 epoch: 0, update in batch 207000/???, loss: 5.858460426330566 epoch: 0, update in batch 208000/???, loss: 6.33021354675293 epoch: 0, update in batch 209000/???, loss: 6.249329566955566 epoch: 0, update in batch 210000/???, loss: 6.263474941253662 epoch: 0, update in batch 211000/???, loss: 6.731234550476074 epoch: 0, update in batch 212000/???, loss: 5.978096961975098 epoch: 0, update in batch 213000/???, loss: 5.148629188537598 epoch: 0, update in batch 214000/???, loss: 6.79285192489624 epoch: 0, update in batch 215000/???, loss: 5.943106651306152 epoch: 0, update in batch 216000/???, loss: 5.749272346496582 epoch: 0, update in batch 217000/???, loss: 6.991009712219238 epoch: 0, update in batch 218000/???, loss: 6.21205997467041 epoch: 0, update in batch 219000/???, loss: 7.519427299499512 epoch: 0, update in batch 220000/???, loss: 5.699267387390137 epoch: 0, update in batch 221000/???, loss: 6.05304479598999 epoch: 0, update in batch 222000/???, loss: 6.422593116760254 epoch: 0, update in batch 223000/???, loss: 6.179877281188965 epoch: 0, update in batch 224000/???, loss: 4.841546058654785 epoch: 0, update in batch 225000/???, loss: 6.666176795959473 epoch: 0, update in batch 226000/???, loss: 
5.994054794311523 epoch: 0, update in batch 227000/???, loss: 6.792928218841553 epoch: 0, update in batch 228000/???, loss: 6.9571661949157715 epoch: 0, update in batch 229000/???, loss: 6.198942184448242 epoch: 0, update in batch 230000/???, loss: 5.944539546966553 epoch: 0, update in batch 231000/???, loss: 6.188899040222168 epoch: 0, update in batch 232000/???, loss: 5.826596260070801 epoch: 0, update in batch 233000/???, loss: 5.728386878967285 epoch: 0, update in batch 234000/???, loss: 7.6024885177612305 epoch: 0, update in batch 235000/???, loss: 6.728615760803223 epoch: 0, update in batch 236000/???, loss: 6.2461137771606445 epoch: 0, update in batch 237000/???, loss: 6.3110551834106445 epoch: 0, update in batch 238000/???, loss: 6.12617826461792 epoch: 0, update in batch 239000/???, loss: 6.6068243980407715 epoch: 0, update in batch 240000/???, loss: 7.015429496765137 epoch: 0, update in batch 241000/???, loss: 8.444561004638672 epoch: 0, update in batch 242000/???, loss: 7.289303779602051 epoch: 0, update in batch 243000/???, loss: 6.260491371154785 epoch: 0, update in batch 244000/???, loss: 7.60237979888916 epoch: 0, update in batch 245000/???, loss: 6.295613765716553 epoch: 0, update in batch 246000/???, loss: 5.929107666015625 epoch: 0, update in batch 247000/???, loss: 5.835566997528076 epoch: 0, update in batch 248000/???, loss: 5.837784290313721 epoch: 0, update in batch 249000/???, loss: 5.972233772277832 epoch: 0, update in batch 250000/???, loss: 6.0488996505737305 epoch: 0, update in batch 251000/???, loss: 5.712280750274658 epoch: 0, update in batch 252000/???, loss: 5.9513702392578125 epoch: 0, update in batch 253000/???, loss: 5.636294364929199 epoch: 0, update in batch 254000/???, loss: 5.91803503036499 epoch: 0, update in batch 255000/???, loss: 7.285937309265137 epoch: 0, update in batch 256000/???, loss: 6.4795637130737305 epoch: 0, update in batch 257000/???, loss: 6.0709991455078125 epoch: 0, update in batch 258000/???, loss: 
5.8723649978637695 epoch: 0, update in batch 259000/???, loss: 5.174002647399902 epoch: 0, update in batch 260000/???, loss: 6.504033088684082 epoch: 0, update in batch 261000/???, loss: 7.088961601257324 epoch: 0, update in batch 262000/???, loss: 6.2242960929870605 epoch: 0, update in batch 263000/???, loss: 5.970286846160889 epoch: 0, update in batch 264000/???, loss: 5.961676597595215 epoch: 0, update in batch 265000/???, loss: 6.170080661773682 epoch: 0, update in batch 266000/???, loss: 5.477972507476807 epoch: 0, update in batch 267000/???, loss: 6.188825607299805 epoch: 0, update in batch 268000/???, loss: 6.518698215484619 epoch: 0, update in batch 269000/???, loss: 5.663434028625488 epoch: 0, update in batch 270000/???, loss: 5.978742599487305 epoch: 0, update in batch 271000/???, loss: 6.217379093170166 epoch: 0, update in batch 272000/???, loss: 5.426600933074951 epoch: 0, update in batch 273000/???, loss: 6.7220964431762695 epoch: 0, update in batch 274000/???, loss: 4.276306629180908 epoch: 0, update in batch 275000/???, loss: 5.420112609863281 epoch: 0, update in batch 276000/???, loss: 5.934456825256348 epoch: 0, update in batch 277000/???, loss: 7.186459541320801 epoch: 0, update in batch 278000/???, loss: 6.126835823059082 epoch: 0, update in batch 279000/???, loss: 5.727339267730713 epoch: 0, update in batch 280000/???, loss: 5.725864410400391 epoch: 0, update in batch 281000/???, loss: 5.47005033493042 epoch: 0, update in batch 282000/???, loss: 6.217499732971191 epoch: 0, update in batch 283000/???, loss: 6.022196292877197 epoch: 0, update in batch 284000/???, loss: 5.932379722595215 epoch: 0, update in batch 285000/???, loss: 6.321987628936768 epoch: 0, update in batch 286000/???, loss: 7.480570316314697 epoch: 0, update in batch 287000/???, loss: 5.169373512268066 epoch: 0, update in batch 288000/???, loss: 6.301320552825928 epoch: 0, update in batch 289000/???, loss: 6.4635009765625 epoch: 0, update in batch 290000/???, loss: 
6.8701887130737305 epoch: 0, update in batch 291000/???, loss: 6.036175727844238 epoch: 0, update in batch 292000/???, loss: 6.705732822418213 epoch: 0, update in batch 293000/???, loss: 6.99608850479126 epoch: 0, update in batch 294000/???, loss: 6.50225305557251 epoch: 0, update in batch 295000/???, loss: 6.03929328918457 epoch: 0, update in batch 296000/???, loss: 5.498082160949707 epoch: 0, update in batch 297000/???, loss: 6.04677677154541 epoch: 0, update in batch 298000/???, loss: 6.482898712158203 epoch: 0, update in batch 299000/???, loss: 7.235076904296875 epoch: 0, update in batch 300000/???, loss: 6.019383907318115 epoch: 0, update in batch 301000/???, loss: 7.082001686096191 epoch: 0, update in batch 302000/???, loss: 6.447659492492676 epoch: 0, update in batch 303000/???, loss: 5.94022798538208 epoch: 0, update in batch 304000/???, loss: 6.459266662597656 epoch: 0, update in batch 305000/???, loss: 6.281588077545166 epoch: 0, update in batch 306000/???, loss: 7.022011756896973 epoch: 0, update in batch 307000/???, loss: 6.1802263259887695 epoch: 0, update in batch 308000/???, loss: 4.189492225646973 epoch: 0, update in batch 309000/???, loss: 6.7040696144104 epoch: 0, update in batch 310000/???, loss: 6.589522361755371 epoch: 0, update in batch 311000/???, loss: 6.243889808654785 epoch: 0, update in batch 312000/???, loss: 5.490180015563965 epoch: 0, update in batch 313000/???, loss: 5.9699201583862305 epoch: 0, update in batch 314000/???, loss: 7.321981906890869 epoch: 0, update in batch 315000/???, loss: 4.731215953826904 epoch: 0, update in batch 316000/???, loss: 5.845946788787842 epoch: 0, update in batch 317000/???, loss: 5.917788505554199 epoch: 0, update in batch 318000/???, loss: 6.420014381408691 epoch: 0, update in batch 319000/???, loss: 6.550830841064453 epoch: 0, update in batch 320000/???, loss: 6.751360893249512 epoch: 0, update in batch 321000/???, loss: 5.025134086608887 epoch: 0, update in batch 322000/???, loss: 6.368621826171875 
epoch: 0, update in batch 323000/???, loss: 6.2042083740234375 epoch: 0, update in batch 324000/???, loss: 6.173147678375244 epoch: 0, update in batch 325000/???, loss: 5.865999221801758 epoch: 0, update in batch 326000/???, loss: 6.844902992248535 epoch: 0, update in batch 327000/???, loss: 6.080742359161377 epoch: 0, update in batch 328000/???, loss: 5.41788387298584 epoch: 0, update in batch 329000/???, loss: 5.831374645233154 epoch: 0, update in batch 330000/???, loss: 6.4492506980896 epoch: 0, update in batch 331000/???, loss: 6.220627784729004 epoch: 0, update in batch 332000/???, loss: 5.880006313323975 epoch: 0, update in batch 333000/???, loss: 6.806972503662109 epoch: 0, update in batch 334000/???, loss: 7.165728569030762 epoch: 0, update in batch 335000/???, loss: 6.322948932647705 epoch: 0, update in batch 336000/???, loss: 6.206046104431152 epoch: 0, update in batch 337000/???, loss: 6.097958564758301 epoch: 0, update in batch 338000/???, loss: 6.7682952880859375 epoch: 0, update in batch 339000/???, loss: 5.2390642166137695 epoch: 0, update in batch 340000/???, loss: 6.913119316101074
# One epoch of the right-to-left model, batches of 64 windows.
train(dataset=train_dataset_back, model=model_back, max_epochs=1, batch_size=64)
epoch: 0, update in batch 0/???, loss: 11.3253755569458 epoch: 0, update in batch 1000/???, loss: 5.709358215332031 epoch: 0, update in batch 2000/???, loss: 7.989391326904297 epoch: 0, update in batch 3000/???, loss: 6.578714847564697 epoch: 0, update in batch 4000/???, loss: 7.051873207092285 epoch: 0, update in batch 5000/???, loss: 6.85653018951416 epoch: 0, update in batch 6000/???, loss: 6.812790870666504 epoch: 0, update in batch 7000/???, loss: 6.9604010581970215 epoch: 0, update in batch 8000/???, loss: 6.798591613769531 epoch: 0, update in batch 9000/???, loss: 6.415241241455078 epoch: 0, update in batch 10000/???, loss: 6.6636223793029785 epoch: 0, update in batch 11000/???, loss: 6.593747138977051 epoch: 0, update in batch 12000/???, loss: 6.914702415466309 epoch: 0, update in batch 13000/???, loss: 5.542675971984863 epoch: 0, update in batch 14000/???, loss: 6.5461883544921875 epoch: 0, update in batch 15000/???, loss: 7.507067680358887 epoch: 0, update in batch 16000/???, loss: 5.425755500793457 epoch: 0, update in batch 17000/???, loss: 6.285205841064453 epoch: 0, update in batch 18000/???, loss: 4.223124027252197 epoch: 0, update in batch 19000/???, loss: 6.530254364013672 epoch: 0, update in batch 20000/???, loss: 6.091847896575928 epoch: 0, update in batch 21000/???, loss: 7.088344573974609 epoch: 0, update in batch 22000/???, loss: 5.925537109375 epoch: 0, update in batch 23000/???, loss: 6.3628082275390625 epoch: 0, update in batch 24000/???, loss: 6.604581356048584 epoch: 0, update in batch 25000/???, loss: 6.2706499099731445 epoch: 0, update in batch 26000/???, loss: 6.114742755889893 epoch: 0, update in batch 27000/???, loss: 5.686783790588379 epoch: 0, update in batch 28000/???, loss: 5.5114521980285645 epoch: 0, update in batch 29000/???, loss: 6.999403953552246 epoch: 0, update in batch 30000/???, loss: 5.834499359130859 epoch: 0, update in batch 31000/???, loss: 5.873156547546387 epoch: 0, update in batch 32000/???, loss: 
6.246962547302246 epoch: 0, update in batch 33000/???, loss: 6.742733955383301 epoch: 0, update in batch 34000/???, loss: 6.832881927490234 epoch: 0, update in batch 35000/???, loss: 6.625868320465088 epoch: 0, update in batch 36000/???, loss: 6.653105735778809 epoch: 0, update in batch 37000/???, loss: 6.104651927947998 epoch: 0, update in batch 38000/???, loss: 6.301898002624512 epoch: 0, update in batch 39000/???, loss: 7.377936363220215 epoch: 0, update in batch 40000/???, loss: 6.26895809173584 epoch: 0, update in batch 41000/???, loss: 6.602926731109619 epoch: 0, update in batch 42000/???, loss: 6.419803619384766 epoch: 0, update in batch 43000/???, loss: 7.187136650085449 epoch: 0, update in batch 44000/???, loss: 6.382015705108643 epoch: 0, update in batch 45000/???, loss: 6.044090747833252 epoch: 0, update in batch 46000/???, loss: 5.707688808441162 epoch: 0, update in batch 47000/???, loss: 7.007757663726807 epoch: 0, update in batch 48000/???, loss: 5.365390300750732 epoch: 0, update in batch 49000/???, loss: 5.510242938995361 epoch: 0, update in batch 50000/???, loss: 5.955991268157959 epoch: 0, update in batch 51000/???, loss: 6.2313032150268555 epoch: 0, update in batch 52000/???, loss: 8.19306468963623 epoch: 0, update in batch 53000/???, loss: 6.345375061035156 epoch: 0, update in batch 54000/???, loss: 7.044759273529053 epoch: 0, update in batch 55000/???, loss: 6.2544779777526855 epoch: 0, update in batch 56000/???, loss: 6.315605163574219 epoch: 0, update in batch 57000/???, loss: 5.632706642150879 epoch: 0, update in batch 58000/???, loss: 6.0897536277771 epoch: 0, update in batch 59000/???, loss: 5.562952518463135 epoch: 0, update in batch 60000/???, loss: 5.519134044647217 epoch: 0, update in batch 61000/???, loss: 6.394771099090576 epoch: 0, update in batch 62000/???, loss: 6.147246360778809 epoch: 0, update in batch 63000/???, loss: 5.798914909362793 epoch: 0, update in batch 64000/???, loss: 6.026059627532959 epoch: 0, update in batch 
65000/???, loss: 6.4533233642578125 epoch: 0, update in batch 66000/???, loss: 6.383795738220215 epoch: 0, update in batch 67000/???, loss: 6.466322898864746 epoch: 0, update in batch 68000/???, loss: 6.8227715492248535 epoch: 0, update in batch 69000/???, loss: 6.283398151397705 epoch: 0, update in batch 70000/???, loss: 4.547608375549316 epoch: 0, update in batch 71000/???, loss: 6.008975028991699 epoch: 0, update in batch 72000/???, loss: 5.674825191497803 epoch: 0, update in batch 73000/???, loss: 5.134644508361816 epoch: 0, update in batch 74000/???, loss: 6.906868934631348 epoch: 0, update in batch 75000/???, loss: 6.672898292541504 epoch: 0, update in batch 76000/???, loss: 5.813290596008301 epoch: 0, update in batch 77000/???, loss: 6.296219825744629 epoch: 0, update in batch 78000/???, loss: 6.531443119049072 epoch: 0, update in batch 79000/???, loss: 6.437461853027344 epoch: 0, update in batch 80000/???, loss: 6.2280778884887695 epoch: 0, update in batch 81000/???, loss: 6.805241584777832 epoch: 0, update in batch 82000/???, loss: 7.044824123382568 epoch: 0, update in batch 83000/???, loss: 7.348274230957031 epoch: 0, update in batch 84000/???, loss: 5.826806545257568 epoch: 0, update in batch 85000/???, loss: 5.474950313568115 epoch: 0, update in batch 86000/???, loss: 6.497323036193848 epoch: 0, update in batch 87000/???, loss: 5.88934850692749 epoch: 0, update in batch 88000/???, loss: 5.371798038482666 epoch: 0, update in batch 89000/???, loss: 6.093968391418457 epoch: 0, update in batch 90000/???, loss: 6.115981578826904 epoch: 0, update in batch 91000/???, loss: 6.504927158355713 epoch: 0, update in batch 92000/???, loss: 6.239808082580566 epoch: 0, update in batch 93000/???, loss: 5.384994983673096 epoch: 0, update in batch 94000/???, loss: 6.422779083251953 epoch: 0, update in batch 95000/???, loss: 7.163965702056885 epoch: 0, update in batch 96000/???, loss: 6.44806432723999 epoch: 0, update in batch 97000/???, loss: 6.153664588928223 epoch: 0, 
update in batch 98000/???, loss: 5.9013776779174805 epoch: 0, update in batch 99000/???, loss: 6.198166847229004 epoch: 0, update in batch 100000/???, loss: 5.752341270446777 epoch: 0, update in batch 101000/???, loss: 6.455883979797363 epoch: 0, update in batch 102000/???, loss: 5.270313262939453 epoch: 0, update in batch 103000/???, loss: 6.475237846374512 epoch: 0, update in batch 104000/???, loss: 6.2444844245910645 epoch: 0, update in batch 105000/???, loss: 6.1563720703125 epoch: 0, update in batch 106000/???, loss: 6.12777853012085 epoch: 0, update in batch 107000/???, loss: 6.449145317077637 epoch: 0, update in batch 108000/???, loss: 6.515239715576172 epoch: 0, update in batch 109000/???, loss: 5.6317644119262695 epoch: 0, update in batch 110000/???, loss: 6.09606409072876 epoch: 0, update in batch 111000/???, loss: 7.069797515869141 epoch: 0, update in batch 112000/???, loss: 7.456076145172119 epoch: 0, update in batch 113000/???, loss: 6.668386936187744 epoch: 0, update in batch 114000/???, loss: 7.705430507659912 epoch: 0, update in batch 115000/???, loss: 6.983656883239746 epoch: 0, update in batch 116000/???, loss: 6.320417404174805 epoch: 0, update in batch 117000/???, loss: 7.184473991394043 epoch: 0, update in batch 118000/???, loss: 6.603268623352051 epoch: 0, update in batch 119000/???, loss: 6.670085906982422 epoch: 0, update in batch 120000/???, loss: 6.748586177825928 epoch: 0, update in batch 121000/???, loss: 6.353959560394287 epoch: 0, update in batch 122000/???, loss: 5.138751029968262 epoch: 0, update in batch 123000/???, loss: 6.507109642028809 epoch: 0, update in batch 124000/???, loss: 6.360246181488037 epoch: 0, update in batch 125000/???, loss: 7.164086818695068 epoch: 0, update in batch 126000/???, loss: 5.610747337341309 epoch: 0, update in batch 127000/???, loss: 5.066179275512695 epoch: 0, update in batch 128000/???, loss: 5.688697814941406 epoch: 0, update in batch 129000/???, loss: 6.960330963134766 epoch: 0, update in batch 
130000/???, loss: 5.818534851074219 epoch: 0, update in batch 131000/???, loss: 6.186715602874756 epoch: 0, update in batch 132000/???, loss: 5.825492858886719 epoch: 0, update in batch 133000/???, loss: 5.576340675354004 epoch: 0, update in batch 134000/???, loss: 5.503821849822998 epoch: 0, update in batch 135000/???, loss: 6.428965091705322 epoch: 0, update in batch 136000/???, loss: 5.102448463439941 epoch: 0, update in batch 137000/???, loss: 6.239314556121826 epoch: 0, update in batch 138000/???, loss: 6.028595447540283 epoch: 0, update in batch 139000/???, loss: 6.407244682312012 epoch: 0, update in batch 140000/???, loss: 5.597055912017822 epoch: 0, update in batch 141000/???, loss: 5.823704719543457 epoch: 0, update in batch 142000/???, loss: 6.665535926818848 epoch: 0, update in batch 143000/???, loss: 5.5736894607543945 epoch: 0, update in batch 144000/???, loss: 6.723180294036865 epoch: 0, update in batch 145000/???, loss: 6.378345489501953 epoch: 0, update in batch 146000/???, loss: 5.6936845779418945 epoch: 0, update in batch 147000/???, loss: 5.761658668518066 epoch: 0, update in batch 148000/???, loss: 5.580254077911377 epoch: 0, update in batch 149000/???, loss: 5.733176231384277 epoch: 0, update in batch 150000/???, loss: 6.901691436767578 epoch: 0, update in batch 151000/???, loss: 6.5111589431762695 epoch: 0, update in batch 152000/???, loss: 6.184727668762207 epoch: 0, update in batch 153000/???, loss: 7.407107353210449 epoch: 0, update in batch 154000/???, loss: 6.499199867248535 epoch: 0, update in batch 155000/???, loss: 5.143393516540527 epoch: 0, update in batch 156000/???, loss: 7.60940408706665 epoch: 0, update in batch 157000/???, loss: 6.766045570373535 epoch: 0, update in batch 158000/???, loss: 5.268759727478027 epoch: 0, update in batch 159000/???, loss: 7.558129787445068 epoch: 0, update in batch 160000/???, loss: 8.016000747680664 epoch: 0, update in batch 161000/???, loss: 5.959166526794434 epoch: 0, update in batch 162000/???, 
loss: 5.499085426330566 epoch: 0, update in batch 163000/???, loss: 6.581662654876709 epoch: 0, update in batch 164000/???, loss: 6.681334495544434 epoch: 0, update in batch 165000/???, loss: 7.817207336425781 epoch: 0, update in batch 166000/???, loss: 6.524381160736084 epoch: 0, update in batch 167000/???, loss: 5.903864860534668 epoch: 0, update in batch 168000/???, loss: 5.6087260246276855 epoch: 0, update in batch 169000/???, loss: 5.742824554443359 epoch: 0, update in batch 170000/???, loss: 6.129671096801758 epoch: 0, update in batch 171000/???, loss: 5.879034519195557 epoch: 0, update in batch 172000/???, loss: 6.322129249572754 epoch: 0, update in batch 173000/???, loss: 6.805352210998535 epoch: 0, update in batch 174000/???, loss: 7.162431240081787 epoch: 0, update in batch 175000/???, loss: 6.123959541320801 epoch: 0, update in batch 176000/???, loss: 7.544029235839844 epoch: 0, update in batch 177000/???, loss: 5.4254021644592285 epoch: 0, update in batch 178000/???, loss: 5.784268379211426 epoch: 0, update in batch 179000/???, loss: 5.8633856773376465 epoch: 0, update in batch 180000/???, loss: 6.556314945220947 epoch: 0, update in batch 181000/???, loss: 5.215446472167969 epoch: 0, update in batch 182000/???, loss: 6.079234600067139 epoch: 0, update in batch 183000/???, loss: 7.234827995300293 epoch: 0, update in batch 184000/???, loss: 5.249889373779297 epoch: 0, update in batch 185000/???, loss: 5.083311080932617 epoch: 0, update in batch 186000/???, loss: 6.061867713928223 epoch: 0, update in batch 187000/???, loss: 6.060431480407715 epoch: 0, update in batch 188000/???, loss: 5.572680950164795 epoch: 0, update in batch 189000/???, loss: 5.991988182067871 epoch: 0, update in batch 190000/???, loss: 6.521245002746582 epoch: 0, update in batch 191000/???, loss: 5.128615379333496 epoch: 0, update in batch 192000/???, loss: 5.616750717163086 epoch: 0, update in batch 193000/???, loss: 6.1465044021606445 epoch: 0, update in batch 194000/???, loss: 
5.93985652923584 epoch: 0, update in batch 195000/???, loss: 6.268892765045166 epoch: 0, update in batch 196000/???, loss: 5.928576469421387 epoch: 0, update in batch 197000/???, loss: 5.257290363311768 epoch: 0, update in batch 198000/???, loss: 6.6432952880859375 epoch: 0, update in batch 199000/???, loss: 6.898074150085449 epoch: 0, update in batch 200000/???, loss: 7.042447566986084 epoch: 0, update in batch 201000/???, loss: 7.104043483734131 epoch: 0, update in batch 202000/???, loss: 6.238812446594238 epoch: 0, update in batch 203000/???, loss: 6.773525238037109 epoch: 0, update in batch 204000/???, loss: 5.054592132568359 epoch: 0, update in batch 205000/???, loss: 6.854428768157959 epoch: 0, update in batch 206000/???, loss: 5.9983601570129395 epoch: 0, update in batch 207000/???, loss: 5.236695766448975 epoch: 0, update in batch 208000/???, loss: 6.086891174316406 epoch: 0, update in batch 209000/???, loss: 6.134495258331299 epoch: 0, update in batch 210000/???, loss: 6.52248477935791 epoch: 0, update in batch 211000/???, loss: 6.028376579284668 epoch: 0, update in batch 212000/???, loss: 6.140281677246094 epoch: 0, update in batch 213000/???, loss: 6.066422462463379 epoch: 0, update in batch 214000/???, loss: 6.868189334869385 epoch: 0, update in batch 215000/???, loss: 6.641358852386475 epoch: 0, update in batch 216000/???, loss: 6.818638801574707 epoch: 0, update in batch 217000/???, loss: 6.40252685546875 epoch: 0, update in batch 218000/???, loss: 5.561617851257324 epoch: 0, update in batch 219000/???, loss: 6.434267997741699 epoch: 0, update in batch 220000/???, loss: 6.33272123336792 epoch: 0, update in batch 221000/???, loss: 5.75616979598999 epoch: 0, update in batch 222000/???, loss: 6.477814674377441 epoch: 0, update in batch 223000/???, loss: 5.259497165679932 epoch: 0, update in batch 224000/???, loss: 5.8639655113220215 epoch: 0, update in batch 225000/???, loss: 6.469706058502197 epoch: 0, update in batch 226000/???, loss: 5.707249164581299 
epoch: 0, update in batch 227000/???, loss: 6.394181251525879 epoch: 0, update in batch 228000/???, loss: 5.048886299133301 epoch: 0, update in batch 229000/???, loss: 5.842928409576416 epoch: 0, update in batch 230000/???, loss: 5.627688407897949 epoch: 0, update in batch 231000/???, loss: 7.950299263000488 epoch: 0, update in batch 232000/???, loss: 6.771368503570557 epoch: 0, update in batch 233000/???, loss: 5.787235260009766 epoch: 0, update in batch 234000/???, loss: 5.6070780754089355 epoch: 0, update in batch 235000/???, loss: 6.060035705566406 epoch: 0, update in batch 236000/???, loss: 6.894829750061035 epoch: 0, update in batch 237000/???, loss: 5.672856330871582 epoch: 0, update in batch 238000/???, loss: 5.054213523864746 epoch: 0, update in batch 239000/???, loss: 6.484643459320068 epoch: 0, update in batch 240000/???, loss: 5.800728797912598 epoch: 0, update in batch 241000/???, loss: 5.148013591766357 epoch: 0, update in batch 242000/???, loss: 5.529184818267822 epoch: 0, update in batch 243000/???, loss: 5.959448337554932 epoch: 0, update in batch 244000/???, loss: 6.762448787689209 epoch: 0, update in batch 245000/???, loss: 4.907589912414551 epoch: 0, update in batch 246000/???, loss: 6.275182723999023 epoch: 0, update in batch 247000/???, loss: 5.7234015464782715 epoch: 0, update in batch 248000/???, loss: 6.119207859039307 epoch: 0, update in batch 249000/???, loss: 5.297057151794434 epoch: 0, update in batch 250000/???, loss: 5.924614906311035 epoch: 0, update in batch 251000/???, loss: 6.651083469390869 epoch: 0, update in batch 252000/???, loss: 5.7164201736450195 epoch: 0, update in batch 253000/???, loss: 6.105191230773926 epoch: 0, update in batch 254000/???, loss: 5.791018486022949 epoch: 0, update in batch 255000/???, loss: 6.659502983093262 epoch: 0, update in batch 256000/???, loss: 5.613073348999023 epoch: 0, update in batch 257000/???, loss: 7.501049041748047 epoch: 0, update in batch 258000/???, loss: 6.043797492980957 epoch: 0, 
update in batch 259000/???, loss: 7.3587327003479 epoch: 0, update in batch 260000/???, loss: 6.276612281799316 epoch: 0, update in batch 261000/???, loss: 6.445192813873291 epoch: 0, update in batch 262000/???, loss: 5.0266547203063965 epoch: 0, update in batch 263000/???, loss: 6.404935359954834 epoch: 0, update in batch 264000/???, loss: 6.5042290687561035 epoch: 0, update in batch 265000/???, loss: 6.880773067474365 epoch: 0, update in batch 266000/???, loss: 6.3690643310546875 epoch: 0, update in batch 267000/???, loss: 6.055562973022461 epoch: 0, update in batch 268000/???, loss: 5.796906471252441 epoch: 0, update in batch 269000/???, loss: 5.654962539672852 epoch: 0, update in batch 270000/???, loss: 6.574362277984619 epoch: 0, update in batch 271000/???, loss: 6.256768226623535 epoch: 0, update in batch 272000/???, loss: 6.8345208168029785 epoch: 0, update in batch 273000/???, loss: 6.066469669342041 epoch: 0, update in batch 274000/???, loss: 6.625809669494629 epoch: 0, update in batch 275000/???, loss: 4.762896537780762 epoch: 0, update in batch 276000/???, loss: 6.019833564758301 epoch: 0, update in batch 277000/???, loss: 6.227939605712891 epoch: 0, update in batch 278000/???, loss: 7.046879768371582 epoch: 0, update in batch 279000/???, loss: 6.068551540374756 epoch: 0, update in batch 280000/???, loss: 6.454771995544434 epoch: 0, update in batch 281000/???, loss: 3.9379985332489014 epoch: 0, update in batch 282000/???, loss: 5.615240097045898 epoch: 0, update in batch 283000/???, loss: 5.7963151931762695 epoch: 0, update in batch 284000/???, loss: 6.064437389373779 epoch: 0, update in batch 285000/???, loss: 6.668734073638916 epoch: 0, update in batch 286000/???, loss: 6.776829719543457 epoch: 0, update in batch 287000/???, loss: 6.170516014099121 epoch: 0, update in batch 288000/???, loss: 4.840399742126465 epoch: 0, update in batch 289000/???, loss: 6.333052635192871 epoch: 0, update in batch 290000/???, loss: 5.595047950744629 epoch: 0, update in 
batch 291000/???, loss: 6.594934940338135 epoch: 0, update in batch 292000/???, loss: 5.950274467468262 epoch: 0, update in batch 293000/???, loss: 6.123660087585449 epoch: 0, update in batch 294000/???, loss: 5.904355049133301 epoch: 0, update in batch 295000/???, loss: 5.8828630447387695 epoch: 0, update in batch 296000/???, loss: 5.604973316192627 epoch: 0, update in batch 297000/???, loss: 4.842469692230225 epoch: 0, update in batch 298000/???, loss: 5.862446308135986 epoch: 0, update in batch 299000/???, loss: 6.90258264541626 epoch: 0, update in batch 300000/???, loss: 5.941957950592041 epoch: 0, update in batch 301000/???, loss: 5.697750568389893 epoch: 0, update in batch 302000/???, loss: 5.973014831542969 epoch: 0, update in batch 303000/???, loss: 5.46022367477417 epoch: 0, update in batch 304000/???, loss: 6.5218095779418945 epoch: 0, update in batch 305000/???, loss: 6.392545700073242 epoch: 0, update in batch 306000/???, loss: 7.080249786376953 epoch: 0, update in batch 307000/???, loss: 6.355096817016602 epoch: 0, update in batch 308000/???, loss: 5.625491619110107 epoch: 0, update in batch 309000/???, loss: 6.805799961090088 epoch: 0, update in batch 310000/???, loss: 6.426385402679443 epoch: 0, update in batch 311000/???, loss: 5.727842807769775 epoch: 0, update in batch 312000/???, loss: 6.9111199378967285 epoch: 0, update in batch 313000/???, loss: 6.40056848526001 epoch: 0, update in batch 314000/???, loss: 6.145076751708984 epoch: 0, update in batch 315000/???, loss: 6.097104072570801 epoch: 0, update in batch 316000/???, loss: 5.39146089553833 epoch: 0, update in batch 317000/???, loss: 6.125569820404053 epoch: 0, update in batch 318000/???, loss: 6.533677577972412 epoch: 0, update in batch 319000/???, loss: 5.944211483001709 epoch: 0, update in batch 320000/???, loss: 6.542410850524902 epoch: 0, update in batch 321000/???, loss: 5.699315071105957 epoch: 0, update in batch 322000/???, loss: 6.251957893371582 epoch: 0, update in batch 
323000/???, loss: 5.346350193023682 epoch: 0, update in batch 324000/???, loss: 5.603858470916748 epoch: 0, update in batch 325000/???, loss: 5.740134239196777 epoch: 0, update in batch 326000/???, loss: 5.575300693511963 epoch: 0, update in batch 327000/???, loss: 6.996762752532959 epoch: 0, update in batch 328000/???, loss: 6.28995418548584 epoch: 0, update in batch 329000/???, loss: 4.519123077392578 epoch: 0, update in batch 330000/???, loss: 5.9068121910095215 epoch: 0, update in batch 331000/???, loss: 6.61830997467041 epoch: 0, update in batch 332000/???, loss: 6.063097953796387 epoch: 0, update in batch 333000/???, loss: 6.419328212738037 epoch: 0, update in batch 334000/???, loss: 5.927584648132324 epoch: 0, update in batch 335000/???, loss: 5.527887344360352 epoch: 0, update in batch 336000/???, loss: 6.114096641540527 epoch: 0, update in batch 337000/???, loss: 5.9415082931518555 epoch: 0, update in batch 338000/???, loss: 5.288441181182861 epoch: 0, update in batch 339000/???, loss: 6.611715793609619 epoch: 0, update in batch 340000/???, loss: 6.770573616027832
def predict_probs(left_tokens, right_tokens, top_k=30):
    """Build a gap-filling prediction line from the left and right contexts.

    The forward model (``model_front``) reads the words left of the gap in
    reading order. The backward model (``model_back``) reads the words right
    of the gap, which the caller passes already reversed. The two models'
    softmax distributions at their last position are averaged element-wise.

    Args:
        left_tokens: tokens immediately preceding the gap, in reading order.
        right_tokens: tokens immediately following the gap, reversed so the
            word closest to the gap comes last.
        top_k: number of candidate words to emit (default 30, as before).

    Returns:
        A string ``"w1:p1 w2:p2 ... wk:pk :rest"``. Candidates are ordered by
        decreasing averaged probability, ``<unk>`` is never emitted, and
        ``rest`` is the probability mass left for all other words, clamped
        at 0.
    """
    model_front.eval()
    model_back.eval()

    # Encode with the vocabulary of the training dataset, for both the
    # membership test and the lookup. The original tested membership in the
    # global dict but looked words up in the dataset's dict.
    vocab = train_dataset_front.key_to_index
    index_to_word = train_dataset_front.index_to_key
    unk_index = vocab['<unk>']

    def encode(tokens):
        ids = [vocab.get(token, unk_index) for token in tokens]
        return torch.tensor([ids], dtype=torch.long, device=device)

    # Inference only: no_grad avoids building an autograd graph on every call.
    with torch.no_grad():
        y_pred_left, _ = model_front(encode(left_tokens))
        y_pred_right, _ = model_back(encode(right_tokens))
        # [0][-1] selects the logits of the last position of the single batch
        # item (assumes model output is (batch, seq, vocab), as originally).
        probs_left = torch.softmax(y_pred_left[0][-1], dim=0)
        probs_right = torch.softmax(y_pred_right[0][-1], dim=0)
        # Vectorised element-wise mean of the two distributions.
        probs = ((probs_left + probs_right) / 2).cpu().numpy()

    # Rank all words once. The stable sort keeps lower indices first on ties.
    # Then drop <unk> and keep the best top_k.
    ranked = np.argsort(-probs, kind='stable')
    top_indices = ranked[ranked != unk_index][:top_k]

    parts = [f'{index_to_word[i]}:{probs[i]}' for i in top_indices]
    # Float error can push the remainder marginally below zero; clamp it.
    rest = max(0.0, 1.0 - float(probs[top_indices].sum()))
    parts.append(f':{rest}')
    return ' '.join(parts)
def _write_predictions(data, out_path, context_size=5):
    """Write one prediction line per row of *data* to *out_path*.

    Column 6 of each row is the text left of the gap and column 7 the text
    right of it, the same columns used to build the training corpus. Rows
    with fewer than ``context_size`` tokens on either side get the
    uninformative prediction ``':1.0'``.
    """
    with open(out_path, 'w', encoding='utf-8') as file:
        for _, row in data.iterrows():
            left_words = word_tokenize(clean_text(str(row[6])))
            # The backward model is fed the right context reversed, so its
            # trailing tokens are the words closest to the gap.
            right_words = word_tokenize(clean_text(str(row[7])))[::-1]
            # Exactly context_size tokens per side are fed to the models; the
            # previous guard demanded one more token than was ever used.
            if len(left_words) < context_size or len(right_words) < context_size:
                prediction = ':1.0'
            else:
                prediction = predict_probs(left_words[-context_size:],
                                           right_words[-context_size:])
            file.write(prediction + '\n')


# NOTE(review): on_bad_lines='skip' silently drops malformed input rows, which
# would leave out.tsv with fewer lines than in.tsv and misalign every
# following prediction. Confirm the dev/test inputs contain no such rows.
dev_data = pd.read_csv('dev-0/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
test_data = pd.read_csv('test-A/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
_write_predictions(dev_data, 'dev-0/out.tsv')
_write_predictions(test_data, 'test-A/out.tsv')