challenging-america-word-ga.../run.ipynb

62 KiB

import pandas as pd
import numpy as np
import regex as re
import csv
import torch
from torch import nn
from gensim.models import Word2Vec
from nltk.tokenize import word_tokenize
# Release cached GPU memory that a previous run of the notebook may still hold.
torch.cuda.empty_cache()
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def clean_text(text):
    """Normalise a raw text fragment before tokenisation.

    Steps, in order:
      1. lower-case the text;
      2. delete the escaped "hyphen + line break" sequence and turn the
         remaining escaped line breaks into spaces;
      3. expand the contractions "'t", "'s", "'ll", "'m" and "'ve";
      4. delete every character in the Unicode punctuation category P
         (using the ``regex`` module, imported in this file as ``re``).

    Contraction expansion has to run before step 4. The apostrophe is itself
    punctuation, so once it has been stripped none of the contraction
    patterns can match. They previously ran after the strip and never fired.

    Args:
        text: raw text (str).

    Returns:
        The cleaned text.
    """
    # NOTE(review): each literal below is a run of 16 backslashes in source
    # (8 in the runtime string) followed by "n". If the raw TSV encodes line
    # breaks with a shorter escape, these replaces never fire. This looks
    # like an over-escaping artifact of the notebook export; confirm against
    # train/in.tsv.xz.
    text = text.lower().replace('-\\\\\\\\\\\\\\\\n', '').replace('\\\\\\\\\\\\\\\\n', ' ')

    # NOTE: "'s" -> " is" also rewrites possessives ("king's" -> "king is").
    text = (
        text.replace("'t", " not")
        .replace("'s", " is")
        .replace("'ll", " will")
        .replace("'m", " am")
        .replace("'ve", " have")
    )

    return re.sub(r'\p{P}', '', text)
# keep_default_na=False: by default pandas parses strings such as "None",
# "NA", "null", "nan" (and empty fields) as NaN. str() later turns NaN into
# the token "nan", so a gap word like "None" was silently replaced.
# NOTE(review): on_bad_lines='skip' is applied to each file independently.
# If a line is skipped in one file but not the other, inputs and labels are
# concatenated out of alignment below. Confirm that no lines are skipped.
train_data = pd.read_csv(
    'train/in.tsv.xz',
    sep='\t',
    on_bad_lines='skip',
    header=None,
    quoting=csv.QUOTE_NONE,
    keep_default_na=False,
)
train_labels = pd.read_csv(
    'train/expected.tsv',
    sep='\t',
    on_bad_lines='skip',
    header=None,
    quoting=csv.QUOTE_NONE,
    keep_default_na=False,
)

# Keep columns 6 and 7 (text placed before / after the gap word when rows are
# assembled), then append the expected gap word as column 0, aligned by index.
train_data = train_data[[6, 7]]
train_data = pd.concat([train_data, train_labels], axis=1)
class TrainCorpus:
    """Restartable iterable of tokenised training rows, for Word2Vec.

    Every ``__iter__`` call re-reads ``data`` from the start. This satisfies
    gensim's requirement that ``corpus_iterable`` be re-iterable, unlike a
    one-shot generator.

    Args:
        data: DataFrame with columns 6, 0 and 7 (text before the gap, the gap
            word, text after the gap).
    """

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        for _, row in self.data.iterrows():
            # Join the three pieces with spaces. Plain concatenation fused the
            # gap word with the last word before it and the first word after it.
            text = ' '.join((str(row[6]), str(row[0]), str(row[7])))
            yield word_tokenize(clean_text(text))
train_sentences = TrainCorpus(train_data.head(80000))
w2v_model = Word2Vec(vector_size=100, min_count=10)
# Only the vocabulary is taken from Word2Vec here. No w2v_model.train() call
# appears in this section; words seen fewer than 10 times are dropped.
w2v_model.build_vocab(corpus_iterable=train_sentences)

# Work on copies. Appending '<unk>' to the KeyedVectors' own list/dict would
# leave w2v_model.wv claiming one more key than it has vectors.
index_to_key = list(w2v_model.wv.index_to_key)
key_to_index = dict(w2v_model.wv.key_to_index)

# Reserve one id for out-of-vocabulary words (guarded so it is never added twice).
if '<unk>' not in key_to_index:
    key_to_index['<unk>'] = len(index_to_key)
    index_to_key.append('<unk>')

vocab_size = len(index_to_key)
print(vocab_size)
81477
class TrainDataset(torch.utils.data.IterableDataset):
    """Stream (input, target) windows of token ids for language-model training.

    Each row is rebuilt as "<column 6> <column 0> <column 7>" (text before
    the gap, gap word, text after the gap), cleaned with ``clean_text`` and
    tokenised with ``word_tokenize``. For every position
    ``i >= context_size``, it yields the ids of ``tokens[i-context_size:i]``
    as input and of the same window shifted one token to the right as target.
    Both are int64 numpy arrays of length ``context_size``. Words missing
    from ``key_to_index`` map to the id of ``'<unk>'``.

    Args:
        data: DataFrame with columns 6, 0 and 7.
        index_to_key: id -> word list (stored; not read while iterating).
        key_to_index: word -> id mapping; must contain ``'<unk>'``.
        reversed: if True, each row's tokens are reversed, so the windows
            run right-to-left.
        context_size: window length (default 5, the original hard-coded value).
    """

    def __init__(self, data, index_to_key, key_to_index, reversed=False, context_size=5):
        self.reversed = reversed
        self.data = data
        self.index_to_key = index_to_key
        self.key_to_index = key_to_index
        self.vocab_size = len(key_to_index)
        self.context_size = context_size

    def __iter__(self):
        unk = self.key_to_index['<unk>']
        n = self.context_size
        for _, row in self.data.iterrows():
            # Join with spaces. Plain concatenation glued the gap word onto
            # its neighbours.
            text = ' '.join((str(row[6]), str(row[0]), str(row[7])))
            tokens = word_tokenize(clean_text(text))
            if self.reversed:
                tokens.reverse()
            # Map each token once per row instead of once per window it falls in.
            ids = [self.key_to_index.get(word, unk) for word in tokens]
            for i in range(n, len(ids)):
                yield (
                    np.asarray(ids[i - n:i], dtype=np.int64),
                    np.asarray(ids[i - n + 1:i + 1], dtype=np.int64),
                )
class Model(nn.Module):
    """Word-level LSTM language model: embedding -> stacked LSTM -> vocabulary logits.

    ``forward`` takes token ids shaped (batch, seq_len), the layout that
    ``DataLoader`` produces when stacking ``TrainDataset`` windows. It returns
    logits shaped (batch, seq_len, vocab_size).

    The LSTM is created with ``batch_first=True``. With PyTorch's default
    (``batch_first=False``), dim 0 is read as time, so the recurrence used to
    run across the samples of a batch instead of along each sequence.

    Args:
        embed_size: dimension of the token embeddings.
        vocab_size: number of token ids (embedding rows and output logits).
        lstm_size: LSTM hidden size (default 128, the original value).
        num_layers: number of stacked LSTM layers (default 2).
        dropout: dropout applied between LSTM layers (default 0.2).
    """

    def __init__(self, embed_size, vocab_size, lstm_size=128, num_layers=2, dropout=0.2):
        super().__init__()
        self.embed_size = embed_size
        self.vocab_size = vocab_size
        self.lstm_size = lstm_size
        self.num_layers = num_layers

        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=embed_size)
        self.lstm = nn.LSTM(
            input_size=embed_size,
            hidden_size=lstm_size,
            num_layers=num_layers,
            dropout=dropout,
            batch_first=True,
        )
        self.fc = nn.Linear(lstm_size, vocab_size)

    def forward(self, x, prev_state=None):
        """Run the model over a batch of token-id sequences.

        Args:
            x: integer token ids shaped (batch, seq_len).
            prev_state: optional ``(h, c)`` tuple, each shaped
                (num_layers, batch, lstm_size); ``None`` means zeros.

        Returns:
            ``(logits, (h, c))``: unnormalised scores shaped
            (batch, seq_len, vocab_size) and the final LSTM state. For
            probabilities, apply softmax over the last (vocabulary) dimension.
        """
        output, state = self.lstm(self.embed(x), prev_state)
        return self.fc(output), state

    def init_state(self, sequence_length):
        """Return a zero ``(h, c)`` LSTM state.

        Args:
            sequence_length: size of the state's batch dimension, i.e. the
                number of sequences in the batch it will be used with. The
                name is kept for backward compatibility.

        Returns:
            Two distinct zero tensors shaped
            (num_layers, sequence_length, lstm_size), on the same device and
            dtype as the model's weights.
        """
        # The previous version read a nonexistent ``self.gru_size`` and used a global device.
        weight = self.fc.weight
        shape = (self.num_layers, sequence_length, self.lstm_size)
        return (
            torch.zeros(shape, device=weight.device, dtype=weight.dtype),
            torch.zeros(shape, device=weight.device, dtype=weight.dtype),
        )
from torch.utils.data import DataLoader
from torch.optim import Adam

def train(dataset, model, max_epochs, batch_size, lr=0.001, log_every=1000):
    """Train ``model`` as a next-token predictor with Adam and cross-entropy.

    Args:
        dataset: yields ``(input_ids, target_ids)`` pairs of equal length.
        model: module whose forward returns ``(logits, state)`` with logits
            shaped (batch, seq_len, vocab_size).
        max_epochs: number of passes over ``dataset``.
        batch_size: samples per optimisation step.
        lr: Adam learning rate (default 0.001, the original value).
        log_every: print the current loss every ``log_every`` batches.
    """
    model.train()
    # Send batches to wherever the model's weights live, not to a global device.
    device = next(model.parameters()).device

    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=lr)

    for epoch in range(max_epochs):
        for batch, (x, y) in enumerate(dataloader):
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            logits, _ = model(x)
            # CrossEntropyLoss expects the class dimension second:
            # (batch, vocab_size, seq_len) against targets (batch, seq_len).
            loss = criterion(logits.transpose(1, 2), y)
            loss.backward()
            optimizer.step()

            if batch % log_every == 0:
                print(f'epoch: {epoch}, update in batch {batch}/???, loss: {loss.item()}')
# Forward model reads each row's tokens left-to-right; the backward model
# reads the reversed token stream of the last 80k rows.
# NOTE(review): the vocabulary was built from train_data.head(80000), so the
# tail(80000) rows used for the backward model may map more words to '<unk>'.
# Confirm the head/tail split is intended.
train_dataset_front = TrainDataset(train_data.head(80000), index_to_key, key_to_index, reversed=False)
train_dataset_back = TrainDataset(train_data.tail(80000), index_to_key, key_to_index, reversed=True)

model_front, model_back = [Model(100, vocab_size).to(device) for _ in range(2)]

train(train_dataset_front, model_front, max_epochs=1, batch_size=64)
epoch: 0, update in batch 0/???, loss: 11.314821243286133
epoch: 0, update in batch 1000/???, loss: 6.876476287841797
epoch: 0, update in batch 2000/???, loss: 7.133523464202881
epoch: 0, update in batch 3000/???, loss: 6.979971885681152
epoch: 0, update in batch 4000/???, loss: 7.018368721008301
epoch: 0, update in batch 5000/???, loss: 6.494096279144287
epoch: 0, update in batch 6000/???, loss: 6.448479652404785
epoch: 0, update in batch 7000/???, loss: 6.526387691497803
epoch: 0, update in batch 8000/???, loss: 6.536323547363281
epoch: 0, update in batch 9000/???, loss: 6.4919538497924805
epoch: 0, update in batch 10000/???, loss: 6.435188293457031
epoch: 0, update in batch 11000/???, loss: 6.934823513031006
epoch: 0, update in batch 12000/???, loss: 7.410381317138672
epoch: 0, update in batch 13000/???, loss: 8.227864265441895
epoch: 0, update in batch 14000/???, loss: 6.7139105796813965
epoch: 0, update in batch 15000/???, loss: 6.82781457901001
epoch: 0, update in batch 16000/???, loss: 6.637822151184082
epoch: 0, update in batch 17000/???, loss: 6.2633233070373535
epoch: 0, update in batch 18000/???, loss: 6.512040138244629
epoch: 0, update in batch 19000/???, loss: 5.745478630065918
epoch: 0, update in batch 20000/???, loss: 7.039064884185791
epoch: 0, update in batch 21000/???, loss: 7.151158332824707
epoch: 0, update in batch 22000/???, loss: 6.460148811340332
epoch: 0, update in batch 23000/???, loss: 7.396632194519043
epoch: 0, update in batch 24000/???, loss: 5.907363414764404
epoch: 0, update in batch 25000/???, loss: 6.669890403747559
epoch: 0, update in batch 26000/???, loss: 6.032290458679199
epoch: 0, update in batch 27000/???, loss: 6.192468166351318
epoch: 0, update in batch 28000/???, loss: 5.757508277893066
epoch: 0, update in batch 29000/???, loss: 7.097552299499512
epoch: 0, update in batch 30000/???, loss: 6.8356804847717285
epoch: 0, update in batch 31000/???, loss: 4.938998699188232
epoch: 0, update in batch 32000/???, loss: 6.34550142288208
epoch: 0, update in batch 33000/???, loss: 7.154759883880615
epoch: 0, update in batch 34000/???, loss: 6.8563055992126465
epoch: 0, update in batch 35000/???, loss: 6.831148624420166
epoch: 0, update in batch 36000/???, loss: 6.867754936218262
epoch: 0, update in batch 37000/???, loss: 6.911463260650635
epoch: 0, update in batch 38000/???, loss: 6.637528896331787
epoch: 0, update in batch 39000/???, loss: 6.822340488433838
epoch: 0, update in batch 40000/???, loss: 6.122499942779541
epoch: 0, update in batch 41000/???, loss: 6.454296112060547
epoch: 0, update in batch 42000/???, loss: 7.5895185470581055
epoch: 0, update in batch 43000/???, loss: 5.775805473327637
epoch: 0, update in batch 44000/???, loss: 5.973118305206299
epoch: 0, update in batch 45000/???, loss: 5.7727460861206055
epoch: 0, update in batch 46000/???, loss: 6.376847267150879
epoch: 0, update in batch 47000/???, loss: 5.739894866943359
epoch: 0, update in batch 48000/???, loss: 6.390743732452393
epoch: 0, update in batch 49000/???, loss: 7.724233150482178
epoch: 0, update in batch 50000/???, loss: 5.242608070373535
epoch: 0, update in batch 51000/???, loss: 5.412053108215332
epoch: 0, update in batch 52000/???, loss: 6.590373992919922
epoch: 0, update in batch 53000/???, loss: 6.46323299407959
epoch: 0, update in batch 54000/???, loss: 6.9850263595581055
epoch: 0, update in batch 55000/???, loss: 7.3167219161987305
epoch: 0, update in batch 56000/???, loss: 6.285423278808594
epoch: 0, update in batch 57000/???, loss: 7.417998313903809
epoch: 0, update in batch 58000/???, loss: 6.437861442565918
epoch: 0, update in batch 59000/???, loss: 6.522177219390869
epoch: 0, update in batch 60000/???, loss: 5.9156928062438965
epoch: 0, update in batch 61000/???, loss: 4.946429252624512
epoch: 0, update in batch 62000/???, loss: 6.633675575256348
epoch: 0, update in batch 63000/???, loss: 7.357038974761963
epoch: 0, update in batch 64000/???, loss: 5.774768352508545
epoch: 0, update in batch 65000/???, loss: 6.289044380187988
epoch: 0, update in batch 66000/???, loss: 6.127488136291504
epoch: 0, update in batch 67000/???, loss: 5.059685230255127
epoch: 0, update in batch 68000/???, loss: 6.5439910888671875
epoch: 0, update in batch 69000/???, loss: 6.679286956787109
epoch: 0, update in batch 70000/???, loss: 7.2232346534729
epoch: 0, update in batch 71000/???, loss: 6.13685941696167
epoch: 0, update in batch 72000/???, loss: 5.766592025756836
epoch: 0, update in batch 73000/???, loss: 6.772070407867432
epoch: 0, update in batch 74000/???, loss: 7.369122505187988
epoch: 0, update in batch 75000/???, loss: 6.598935127258301
epoch: 0, update in batch 76000/???, loss: 5.948511600494385
epoch: 0, update in batch 77000/???, loss: 6.507765769958496
epoch: 0, update in batch 78000/???, loss: 5.09373664855957
epoch: 0, update in batch 79000/???, loss: 5.9862494468688965
epoch: 0, update in batch 80000/???, loss: 6.106108665466309
epoch: 0, update in batch 81000/???, loss: 5.2747578620910645
epoch: 0, update in batch 82000/???, loss: 6.324326515197754
epoch: 0, update in batch 83000/???, loss: 5.914392471313477
epoch: 0, update in batch 84000/???, loss: 6.641409873962402
epoch: 0, update in batch 85000/???, loss: 6.287321090698242
epoch: 0, update in batch 86000/???, loss: 6.510883331298828
epoch: 0, update in batch 87000/???, loss: 6.458550930023193
epoch: 0, update in batch 88000/???, loss: 6.07730770111084
epoch: 0, update in batch 89000/???, loss: 6.2387471199035645
epoch: 0, update in batch 90000/???, loss: 5.63344669342041
epoch: 0, update in batch 91000/???, loss: 6.277956962585449
epoch: 0, update in batch 92000/???, loss: 6.841054439544678
epoch: 0, update in batch 93000/???, loss: 6.458809852600098
epoch: 0, update in batch 94000/???, loss: 7.471741676330566
epoch: 0, update in batch 95000/???, loss: 6.461136817932129
epoch: 0, update in batch 96000/???, loss: 5.718675136566162
epoch: 0, update in batch 97000/???, loss: 4.4265007972717285
epoch: 0, update in batch 98000/???, loss: 7.05142879486084
epoch: 0, update in batch 99000/???, loss: 6.341854572296143
epoch: 0, update in batch 100000/???, loss: 6.834918022155762
epoch: 0, update in batch 101000/???, loss: 5.367598056793213
epoch: 0, update in batch 102000/???, loss: 5.716221809387207
epoch: 0, update in batch 103000/???, loss: 6.9465742111206055
epoch: 0, update in batch 104000/???, loss: 5.976019382476807
epoch: 0, update in batch 105000/???, loss: 6.125661849975586
epoch: 0, update in batch 106000/???, loss: 6.724229335784912
epoch: 0, update in batch 107000/???, loss: 6.446004390716553
epoch: 0, update in batch 108000/???, loss: 6.4710845947265625
epoch: 0, update in batch 109000/???, loss: 6.5926103591918945
epoch: 0, update in batch 110000/???, loss: 6.966839790344238
epoch: 0, update in batch 111000/???, loss: 7.263918876647949
epoch: 0, update in batch 112000/???, loss: 6.7561750411987305
epoch: 0, update in batch 113000/???, loss: 6.142555236816406
epoch: 0, update in batch 114000/???, loss: 5.974082946777344
epoch: 0, update in batch 115000/???, loss: 5.565796852111816
epoch: 0, update in batch 116000/???, loss: 6.4826202392578125
epoch: 0, update in batch 117000/???, loss: 5.643266201019287
epoch: 0, update in batch 118000/???, loss: 6.360909461975098
epoch: 0, update in batch 119000/???, loss: 5.4074201583862305
epoch: 0, update in batch 120000/???, loss: 7.1339569091796875
epoch: 0, update in batch 121000/???, loss: 6.786561012268066
epoch: 0, update in batch 122000/???, loss: 6.329574108123779
epoch: 0, update in batch 123000/???, loss: 7.21968936920166
epoch: 0, update in batch 124000/???, loss: 5.351359844207764
epoch: 0, update in batch 125000/???, loss: 7.962380886077881
epoch: 0, update in batch 126000/???, loss: 6.351782321929932
epoch: 0, update in batch 127000/???, loss: 6.8343048095703125
epoch: 0, update in batch 128000/???, loss: 6.129800319671631
epoch: 0, update in batch 129000/???, loss: 6.68627405166626
epoch: 0, update in batch 130000/???, loss: 6.498664855957031
epoch: 0, update in batch 131000/???, loss: 5.724549293518066
epoch: 0, update in batch 132000/???, loss: 7.041095733642578
epoch: 0, update in batch 133000/???, loss: 5.901988983154297
epoch: 0, update in batch 134000/???, loss: 6.055495262145996
epoch: 0, update in batch 135000/???, loss: 6.363399982452393
epoch: 0, update in batch 136000/???, loss: 7.45733642578125
epoch: 0, update in batch 137000/???, loss: 6.960203647613525
epoch: 0, update in batch 138000/???, loss: 6.986503601074219
epoch: 0, update in batch 139000/???, loss: 5.7938127517700195
epoch: 0, update in batch 140000/???, loss: 5.559916019439697
epoch: 0, update in batch 141000/???, loss: 5.551616668701172
epoch: 0, update in batch 142000/???, loss: 5.386819839477539
epoch: 0, update in batch 143000/???, loss: 6.826618194580078
epoch: 0, update in batch 144000/???, loss: 6.106345176696777
epoch: 0, update in batch 145000/???, loss: 6.812024116516113
epoch: 0, update in batch 146000/???, loss: 6.347486972808838
epoch: 0, update in batch 147000/???, loss: 6.20189094543457
epoch: 0, update in batch 148000/???, loss: 5.5717034339904785
epoch: 0, update in batch 149000/???, loss: 6.884232521057129
epoch: 0, update in batch 150000/???, loss: 6.8074846267700195
epoch: 0, update in batch 151000/???, loss: 7.028794288635254
epoch: 0, update in batch 152000/???, loss: 5.201214790344238
epoch: 0, update in batch 153000/???, loss: 5.1864013671875
epoch: 0, update in batch 154000/???, loss: 6.4473114013671875
epoch: 0, update in batch 155000/???, loss: 4.9203643798828125
epoch: 0, update in batch 156000/???, loss: 6.829309940338135
epoch: 0, update in batch 157000/???, loss: 7.045801639556885
epoch: 0, update in batch 158000/???, loss: 6.4073967933654785
epoch: 0, update in batch 159000/???, loss: 6.494145393371582
epoch: 0, update in batch 160000/???, loss: 6.682474613189697
epoch: 0, update in batch 161000/???, loss: 5.125617980957031
epoch: 0, update in batch 162000/???, loss: 5.915367126464844
epoch: 0, update in batch 163000/???, loss: 6.4779157638549805
epoch: 0, update in batch 164000/???, loss: 5.547584533691406
epoch: 0, update in batch 165000/???, loss: 6.134579181671143
epoch: 0, update in batch 166000/???, loss: 5.300144672393799
epoch: 0, update in batch 167000/???, loss: 6.53488826751709
epoch: 0, update in batch 168000/???, loss: 6.711917877197266
epoch: 0, update in batch 169000/???, loss: 7.0150322914123535
epoch: 0, update in batch 170000/???, loss: 5.681846618652344
epoch: 0, update in batch 171000/???, loss: 6.583130836486816
epoch: 0, update in batch 172000/???, loss: 6.411820411682129
epoch: 0, update in batch 173000/???, loss: 5.725490093231201
epoch: 0, update in batch 174000/???, loss: 6.651374816894531
epoch: 0, update in batch 175000/???, loss: 5.800152778625488
epoch: 0, update in batch 176000/???, loss: 6.862998962402344
epoch: 0, update in batch 177000/???, loss: 6.668658256530762
epoch: 0, update in batch 178000/???, loss: 6.519270896911621
epoch: 0, update in batch 179000/???, loss: 6.716788291931152
epoch: 0, update in batch 180000/???, loss: 6.675846099853516
epoch: 0, update in batch 181000/???, loss: 6.598060607910156
epoch: 0, update in batch 182000/???, loss: 6.638599395751953
epoch: 0, update in batch 183000/???, loss: 5.693145275115967
epoch: 0, update in batch 184000/???, loss: 5.175653457641602
epoch: 0, update in batch 185000/???, loss: 6.659600734710693
epoch: 0, update in batch 186000/???, loss: 5.782421112060547
epoch: 0, update in batch 187000/???, loss: 6.1736297607421875
epoch: 0, update in batch 188000/???, loss: 5.38541316986084
epoch: 0, update in batch 189000/???, loss: 6.238187789916992
epoch: 0, update in batch 190000/???, loss: 6.10030460357666
epoch: 0, update in batch 191000/???, loss: 6.680960655212402
epoch: 0, update in batch 192000/???, loss: 6.600944519042969
epoch: 0, update in batch 193000/???, loss: 6.171700477600098
epoch: 0, update in batch 194000/???, loss: 7.250021934509277
epoch: 0, update in batch 195000/???, loss: 5.968771934509277
epoch: 0, update in batch 196000/???, loss: 7.107605934143066
epoch: 0, update in batch 197000/???, loss: 6.743283748626709
epoch: 0, update in batch 198000/???, loss: 7.130635738372803
epoch: 0, update in batch 199000/???, loss: 6.37470817565918
epoch: 0, update in batch 200000/???, loss: 6.050590515136719
epoch: 0, update in batch 201000/???, loss: 5.468177318572998
epoch: 0, update in batch 202000/???, loss: 6.343471527099609
epoch: 0, update in batch 203000/???, loss: 6.890538692474365
epoch: 0, update in batch 204000/???, loss: 7.018721580505371
epoch: 0, update in batch 205000/???, loss: 6.131939888000488
epoch: 0, update in batch 206000/???, loss: 6.219918251037598
epoch: 0, update in batch 207000/???, loss: 5.858460426330566
epoch: 0, update in batch 208000/???, loss: 6.33021354675293
epoch: 0, update in batch 209000/???, loss: 6.249329566955566
epoch: 0, update in batch 210000/???, loss: 6.263474941253662
epoch: 0, update in batch 211000/???, loss: 6.731234550476074
epoch: 0, update in batch 212000/???, loss: 5.978096961975098
epoch: 0, update in batch 213000/???, loss: 5.148629188537598
epoch: 0, update in batch 214000/???, loss: 6.79285192489624
epoch: 0, update in batch 215000/???, loss: 5.943106651306152
epoch: 0, update in batch 216000/???, loss: 5.749272346496582
epoch: 0, update in batch 217000/???, loss: 6.991009712219238
epoch: 0, update in batch 218000/???, loss: 6.21205997467041
epoch: 0, update in batch 219000/???, loss: 7.519427299499512
epoch: 0, update in batch 220000/???, loss: 5.699267387390137
epoch: 0, update in batch 221000/???, loss: 6.05304479598999
epoch: 0, update in batch 222000/???, loss: 6.422593116760254
epoch: 0, update in batch 223000/???, loss: 6.179877281188965
epoch: 0, update in batch 224000/???, loss: 4.841546058654785
epoch: 0, update in batch 225000/???, loss: 6.666176795959473
epoch: 0, update in batch 226000/???, loss: 5.994054794311523
epoch: 0, update in batch 227000/???, loss: 6.792928218841553
epoch: 0, update in batch 228000/???, loss: 6.9571661949157715
epoch: 0, update in batch 229000/???, loss: 6.198942184448242
epoch: 0, update in batch 230000/???, loss: 5.944539546966553
epoch: 0, update in batch 231000/???, loss: 6.188899040222168
epoch: 0, update in batch 232000/???, loss: 5.826596260070801
epoch: 0, update in batch 233000/???, loss: 5.728386878967285
epoch: 0, update in batch 234000/???, loss: 7.6024885177612305
epoch: 0, update in batch 235000/???, loss: 6.728615760803223
epoch: 0, update in batch 236000/???, loss: 6.2461137771606445
epoch: 0, update in batch 237000/???, loss: 6.3110551834106445
epoch: 0, update in batch 238000/???, loss: 6.12617826461792
epoch: 0, update in batch 239000/???, loss: 6.6068243980407715
epoch: 0, update in batch 240000/???, loss: 7.015429496765137
epoch: 0, update in batch 241000/???, loss: 8.444561004638672
epoch: 0, update in batch 242000/???, loss: 7.289303779602051
epoch: 0, update in batch 243000/???, loss: 6.260491371154785
epoch: 0, update in batch 244000/???, loss: 7.60237979888916
epoch: 0, update in batch 245000/???, loss: 6.295613765716553
epoch: 0, update in batch 246000/???, loss: 5.929107666015625
epoch: 0, update in batch 247000/???, loss: 5.835566997528076
epoch: 0, update in batch 248000/???, loss: 5.837784290313721
epoch: 0, update in batch 249000/???, loss: 5.972233772277832
epoch: 0, update in batch 250000/???, loss: 6.0488996505737305
epoch: 0, update in batch 251000/???, loss: 5.712280750274658
epoch: 0, update in batch 252000/???, loss: 5.9513702392578125
epoch: 0, update in batch 253000/???, loss: 5.636294364929199
epoch: 0, update in batch 254000/???, loss: 5.91803503036499
epoch: 0, update in batch 255000/???, loss: 7.285937309265137
epoch: 0, update in batch 256000/???, loss: 6.4795637130737305
epoch: 0, update in batch 257000/???, loss: 6.0709991455078125
epoch: 0, update in batch 258000/???, loss: 5.8723649978637695
epoch: 0, update in batch 259000/???, loss: 5.174002647399902
epoch: 0, update in batch 260000/???, loss: 6.504033088684082
epoch: 0, update in batch 261000/???, loss: 7.088961601257324
epoch: 0, update in batch 262000/???, loss: 6.2242960929870605
epoch: 0, update in batch 263000/???, loss: 5.970286846160889
epoch: 0, update in batch 264000/???, loss: 5.961676597595215
epoch: 0, update in batch 265000/???, loss: 6.170080661773682
epoch: 0, update in batch 266000/???, loss: 5.477972507476807
epoch: 0, update in batch 267000/???, loss: 6.188825607299805
epoch: 0, update in batch 268000/???, loss: 6.518698215484619
epoch: 0, update in batch 269000/???, loss: 5.663434028625488
epoch: 0, update in batch 270000/???, loss: 5.978742599487305
epoch: 0, update in batch 271000/???, loss: 6.217379093170166
epoch: 0, update in batch 272000/???, loss: 5.426600933074951
epoch: 0, update in batch 273000/???, loss: 6.7220964431762695
epoch: 0, update in batch 274000/???, loss: 4.276306629180908
epoch: 0, update in batch 275000/???, loss: 5.420112609863281
epoch: 0, update in batch 276000/???, loss: 5.934456825256348
epoch: 0, update in batch 277000/???, loss: 7.186459541320801
epoch: 0, update in batch 278000/???, loss: 6.126835823059082
epoch: 0, update in batch 279000/???, loss: 5.727339267730713
epoch: 0, update in batch 280000/???, loss: 5.725864410400391
epoch: 0, update in batch 281000/???, loss: 5.47005033493042
epoch: 0, update in batch 282000/???, loss: 6.217499732971191
epoch: 0, update in batch 283000/???, loss: 6.022196292877197
epoch: 0, update in batch 284000/???, loss: 5.932379722595215
epoch: 0, update in batch 285000/???, loss: 6.321987628936768
epoch: 0, update in batch 286000/???, loss: 7.480570316314697
epoch: 0, update in batch 287000/???, loss: 5.169373512268066
epoch: 0, update in batch 288000/???, loss: 6.301320552825928
epoch: 0, update in batch 289000/???, loss: 6.4635009765625
epoch: 0, update in batch 290000/???, loss: 6.8701887130737305
epoch: 0, update in batch 291000/???, loss: 6.036175727844238
epoch: 0, update in batch 292000/???, loss: 6.705732822418213
epoch: 0, update in batch 293000/???, loss: 6.99608850479126
epoch: 0, update in batch 294000/???, loss: 6.50225305557251
epoch: 0, update in batch 295000/???, loss: 6.03929328918457
epoch: 0, update in batch 296000/???, loss: 5.498082160949707
epoch: 0, update in batch 297000/???, loss: 6.04677677154541
epoch: 0, update in batch 298000/???, loss: 6.482898712158203
epoch: 0, update in batch 299000/???, loss: 7.235076904296875
epoch: 0, update in batch 300000/???, loss: 6.019383907318115
epoch: 0, update in batch 301000/???, loss: 7.082001686096191
epoch: 0, update in batch 302000/???, loss: 6.447659492492676
epoch: 0, update in batch 303000/???, loss: 5.94022798538208
epoch: 0, update in batch 304000/???, loss: 6.459266662597656
epoch: 0, update in batch 305000/???, loss: 6.281588077545166
epoch: 0, update in batch 306000/???, loss: 7.022011756896973
epoch: 0, update in batch 307000/???, loss: 6.1802263259887695
epoch: 0, update in batch 308000/???, loss: 4.189492225646973
epoch: 0, update in batch 309000/???, loss: 6.7040696144104
epoch: 0, update in batch 310000/???, loss: 6.589522361755371
epoch: 0, update in batch 311000/???, loss: 6.243889808654785
epoch: 0, update in batch 312000/???, loss: 5.490180015563965
epoch: 0, update in batch 313000/???, loss: 5.9699201583862305
epoch: 0, update in batch 314000/???, loss: 7.321981906890869
epoch: 0, update in batch 315000/???, loss: 4.731215953826904
epoch: 0, update in batch 316000/???, loss: 5.845946788787842
epoch: 0, update in batch 317000/???, loss: 5.917788505554199
epoch: 0, update in batch 318000/???, loss: 6.420014381408691
epoch: 0, update in batch 319000/???, loss: 6.550830841064453
epoch: 0, update in batch 320000/???, loss: 6.751360893249512
epoch: 0, update in batch 321000/???, loss: 5.025134086608887
epoch: 0, update in batch 322000/???, loss: 6.368621826171875
epoch: 0, update in batch 323000/???, loss: 6.2042083740234375
epoch: 0, update in batch 324000/???, loss: 6.173147678375244
epoch: 0, update in batch 325000/???, loss: 5.865999221801758
epoch: 0, update in batch 326000/???, loss: 6.844902992248535
epoch: 0, update in batch 327000/???, loss: 6.080742359161377
epoch: 0, update in batch 328000/???, loss: 5.41788387298584
epoch: 0, update in batch 329000/???, loss: 5.831374645233154
epoch: 0, update in batch 330000/???, loss: 6.4492506980896
epoch: 0, update in batch 331000/???, loss: 6.220627784729004
epoch: 0, update in batch 332000/???, loss: 5.880006313323975
epoch: 0, update in batch 333000/???, loss: 6.806972503662109
epoch: 0, update in batch 334000/???, loss: 7.165728569030762
epoch: 0, update in batch 335000/???, loss: 6.322948932647705
epoch: 0, update in batch 336000/???, loss: 6.206046104431152
epoch: 0, update in batch 337000/???, loss: 6.097958564758301
epoch: 0, update in batch 338000/???, loss: 6.7682952880859375
epoch: 0, update in batch 339000/???, loss: 5.2390642166137695
epoch: 0, update in batch 340000/???, loss: 6.913119316101074
# Train the right-to-left model with the same epochs and batch size as the forward one.
train(train_dataset_back, model_back, max_epochs=1, batch_size=64)
epoch: 0, update in batch 0/???, loss: 11.3253755569458
epoch: 0, update in batch 1000/???, loss: 5.709358215332031
epoch: 0, update in batch 2000/???, loss: 7.989391326904297
epoch: 0, update in batch 3000/???, loss: 6.578714847564697
epoch: 0, update in batch 4000/???, loss: 7.051873207092285
epoch: 0, update in batch 5000/???, loss: 6.85653018951416
epoch: 0, update in batch 6000/???, loss: 6.812790870666504
epoch: 0, update in batch 7000/???, loss: 6.9604010581970215
epoch: 0, update in batch 8000/???, loss: 6.798591613769531
epoch: 0, update in batch 9000/???, loss: 6.415241241455078
epoch: 0, update in batch 10000/???, loss: 6.6636223793029785
epoch: 0, update in batch 11000/???, loss: 6.593747138977051
epoch: 0, update in batch 12000/???, loss: 6.914702415466309
epoch: 0, update in batch 13000/???, loss: 5.542675971984863
epoch: 0, update in batch 14000/???, loss: 6.5461883544921875
epoch: 0, update in batch 15000/???, loss: 7.507067680358887
epoch: 0, update in batch 16000/???, loss: 5.425755500793457
epoch: 0, update in batch 17000/???, loss: 6.285205841064453
epoch: 0, update in batch 18000/???, loss: 4.223124027252197
epoch: 0, update in batch 19000/???, loss: 6.530254364013672
epoch: 0, update in batch 20000/???, loss: 6.091847896575928
epoch: 0, update in batch 21000/???, loss: 7.088344573974609
epoch: 0, update in batch 22000/???, loss: 5.925537109375
epoch: 0, update in batch 23000/???, loss: 6.3628082275390625
epoch: 0, update in batch 24000/???, loss: 6.604581356048584
epoch: 0, update in batch 25000/???, loss: 6.2706499099731445
epoch: 0, update in batch 26000/???, loss: 6.114742755889893
epoch: 0, update in batch 27000/???, loss: 5.686783790588379
epoch: 0, update in batch 28000/???, loss: 5.5114521980285645
epoch: 0, update in batch 29000/???, loss: 6.999403953552246
epoch: 0, update in batch 30000/???, loss: 5.834499359130859
epoch: 0, update in batch 31000/???, loss: 5.873156547546387
epoch: 0, update in batch 32000/???, loss: 6.246962547302246
epoch: 0, update in batch 33000/???, loss: 6.742733955383301
epoch: 0, update in batch 34000/???, loss: 6.832881927490234
epoch: 0, update in batch 35000/???, loss: 6.625868320465088
epoch: 0, update in batch 36000/???, loss: 6.653105735778809
epoch: 0, update in batch 37000/???, loss: 6.104651927947998
epoch: 0, update in batch 38000/???, loss: 6.301898002624512
epoch: 0, update in batch 39000/???, loss: 7.377936363220215
epoch: 0, update in batch 40000/???, loss: 6.26895809173584
epoch: 0, update in batch 41000/???, loss: 6.602926731109619
epoch: 0, update in batch 42000/???, loss: 6.419803619384766
epoch: 0, update in batch 43000/???, loss: 7.187136650085449
epoch: 0, update in batch 44000/???, loss: 6.382015705108643
epoch: 0, update in batch 45000/???, loss: 6.044090747833252
epoch: 0, update in batch 46000/???, loss: 5.707688808441162
epoch: 0, update in batch 47000/???, loss: 7.007757663726807
epoch: 0, update in batch 48000/???, loss: 5.365390300750732
epoch: 0, update in batch 49000/???, loss: 5.510242938995361
epoch: 0, update in batch 50000/???, loss: 5.955991268157959
epoch: 0, update in batch 51000/???, loss: 6.2313032150268555
epoch: 0, update in batch 52000/???, loss: 8.19306468963623
epoch: 0, update in batch 53000/???, loss: 6.345375061035156
epoch: 0, update in batch 54000/???, loss: 7.044759273529053
epoch: 0, update in batch 55000/???, loss: 6.2544779777526855
epoch: 0, update in batch 56000/???, loss: 6.315605163574219
epoch: 0, update in batch 57000/???, loss: 5.632706642150879
epoch: 0, update in batch 58000/???, loss: 6.0897536277771
epoch: 0, update in batch 59000/???, loss: 5.562952518463135
epoch: 0, update in batch 60000/???, loss: 5.519134044647217
epoch: 0, update in batch 61000/???, loss: 6.394771099090576
epoch: 0, update in batch 62000/???, loss: 6.147246360778809
epoch: 0, update in batch 63000/???, loss: 5.798914909362793
epoch: 0, update in batch 64000/???, loss: 6.026059627532959
epoch: 0, update in batch 65000/???, loss: 6.4533233642578125
epoch: 0, update in batch 66000/???, loss: 6.383795738220215
epoch: 0, update in batch 67000/???, loss: 6.466322898864746
epoch: 0, update in batch 68000/???, loss: 6.8227715492248535
epoch: 0, update in batch 69000/???, loss: 6.283398151397705
epoch: 0, update in batch 70000/???, loss: 4.547608375549316
epoch: 0, update in batch 71000/???, loss: 6.008975028991699
epoch: 0, update in batch 72000/???, loss: 5.674825191497803
epoch: 0, update in batch 73000/???, loss: 5.134644508361816
epoch: 0, update in batch 74000/???, loss: 6.906868934631348
epoch: 0, update in batch 75000/???, loss: 6.672898292541504
epoch: 0, update in batch 76000/???, loss: 5.813290596008301
epoch: 0, update in batch 77000/???, loss: 6.296219825744629
epoch: 0, update in batch 78000/???, loss: 6.531443119049072
epoch: 0, update in batch 79000/???, loss: 6.437461853027344
epoch: 0, update in batch 80000/???, loss: 6.2280778884887695
epoch: 0, update in batch 81000/???, loss: 6.805241584777832
epoch: 0, update in batch 82000/???, loss: 7.044824123382568
epoch: 0, update in batch 83000/???, loss: 7.348274230957031
epoch: 0, update in batch 84000/???, loss: 5.826806545257568
epoch: 0, update in batch 85000/???, loss: 5.474950313568115
epoch: 0, update in batch 86000/???, loss: 6.497323036193848
epoch: 0, update in batch 87000/???, loss: 5.88934850692749
epoch: 0, update in batch 88000/???, loss: 5.371798038482666
epoch: 0, update in batch 89000/???, loss: 6.093968391418457
epoch: 0, update in batch 90000/???, loss: 6.115981578826904
epoch: 0, update in batch 91000/???, loss: 6.504927158355713
epoch: 0, update in batch 92000/???, loss: 6.239808082580566
epoch: 0, update in batch 93000/???, loss: 5.384994983673096
epoch: 0, update in batch 94000/???, loss: 6.422779083251953
epoch: 0, update in batch 95000/???, loss: 7.163965702056885
epoch: 0, update in batch 96000/???, loss: 6.44806432723999
epoch: 0, update in batch 97000/???, loss: 6.153664588928223
epoch: 0, update in batch 98000/???, loss: 5.9013776779174805
epoch: 0, update in batch 99000/???, loss: 6.198166847229004
epoch: 0, update in batch 100000/???, loss: 5.752341270446777
epoch: 0, update in batch 101000/???, loss: 6.455883979797363
epoch: 0, update in batch 102000/???, loss: 5.270313262939453
epoch: 0, update in batch 103000/???, loss: 6.475237846374512
epoch: 0, update in batch 104000/???, loss: 6.2444844245910645
epoch: 0, update in batch 105000/???, loss: 6.1563720703125
epoch: 0, update in batch 106000/???, loss: 6.12777853012085
epoch: 0, update in batch 107000/???, loss: 6.449145317077637
epoch: 0, update in batch 108000/???, loss: 6.515239715576172
epoch: 0, update in batch 109000/???, loss: 5.6317644119262695
epoch: 0, update in batch 110000/???, loss: 6.09606409072876
epoch: 0, update in batch 111000/???, loss: 7.069797515869141
epoch: 0, update in batch 112000/???, loss: 7.456076145172119
epoch: 0, update in batch 113000/???, loss: 6.668386936187744
epoch: 0, update in batch 114000/???, loss: 7.705430507659912
epoch: 0, update in batch 115000/???, loss: 6.983656883239746
epoch: 0, update in batch 116000/???, loss: 6.320417404174805
epoch: 0, update in batch 117000/???, loss: 7.184473991394043
epoch: 0, update in batch 118000/???, loss: 6.603268623352051
epoch: 0, update in batch 119000/???, loss: 6.670085906982422
epoch: 0, update in batch 120000/???, loss: 6.748586177825928
epoch: 0, update in batch 121000/???, loss: 6.353959560394287
epoch: 0, update in batch 122000/???, loss: 5.138751029968262
epoch: 0, update in batch 123000/???, loss: 6.507109642028809
epoch: 0, update in batch 124000/???, loss: 6.360246181488037
epoch: 0, update in batch 125000/???, loss: 7.164086818695068
epoch: 0, update in batch 126000/???, loss: 5.610747337341309
epoch: 0, update in batch 127000/???, loss: 5.066179275512695
epoch: 0, update in batch 128000/???, loss: 5.688697814941406
epoch: 0, update in batch 129000/???, loss: 6.960330963134766
epoch: 0, update in batch 130000/???, loss: 5.818534851074219
epoch: 0, update in batch 131000/???, loss: 6.186715602874756
epoch: 0, update in batch 132000/???, loss: 5.825492858886719
epoch: 0, update in batch 133000/???, loss: 5.576340675354004
epoch: 0, update in batch 134000/???, loss: 5.503821849822998
epoch: 0, update in batch 135000/???, loss: 6.428965091705322
epoch: 0, update in batch 136000/???, loss: 5.102448463439941
epoch: 0, update in batch 137000/???, loss: 6.239314556121826
epoch: 0, update in batch 138000/???, loss: 6.028595447540283
epoch: 0, update in batch 139000/???, loss: 6.407244682312012
epoch: 0, update in batch 140000/???, loss: 5.597055912017822
epoch: 0, update in batch 141000/???, loss: 5.823704719543457
epoch: 0, update in batch 142000/???, loss: 6.665535926818848
epoch: 0, update in batch 143000/???, loss: 5.5736894607543945
epoch: 0, update in batch 144000/???, loss: 6.723180294036865
epoch: 0, update in batch 145000/???, loss: 6.378345489501953
epoch: 0, update in batch 146000/???, loss: 5.6936845779418945
epoch: 0, update in batch 147000/???, loss: 5.761658668518066
epoch: 0, update in batch 148000/???, loss: 5.580254077911377
epoch: 0, update in batch 149000/???, loss: 5.733176231384277
epoch: 0, update in batch 150000/???, loss: 6.901691436767578
epoch: 0, update in batch 151000/???, loss: 6.5111589431762695
epoch: 0, update in batch 152000/???, loss: 6.184727668762207
epoch: 0, update in batch 153000/???, loss: 7.407107353210449
epoch: 0, update in batch 154000/???, loss: 6.499199867248535
epoch: 0, update in batch 155000/???, loss: 5.143393516540527
epoch: 0, update in batch 156000/???, loss: 7.60940408706665
epoch: 0, update in batch 157000/???, loss: 6.766045570373535
epoch: 0, update in batch 158000/???, loss: 5.268759727478027
epoch: 0, update in batch 159000/???, loss: 7.558129787445068
epoch: 0, update in batch 160000/???, loss: 8.016000747680664
epoch: 0, update in batch 161000/???, loss: 5.959166526794434
epoch: 0, update in batch 162000/???, loss: 5.499085426330566
epoch: 0, update in batch 163000/???, loss: 6.581662654876709
epoch: 0, update in batch 164000/???, loss: 6.681334495544434
epoch: 0, update in batch 165000/???, loss: 7.817207336425781
epoch: 0, update in batch 166000/???, loss: 6.524381160736084
epoch: 0, update in batch 167000/???, loss: 5.903864860534668
epoch: 0, update in batch 168000/???, loss: 5.6087260246276855
epoch: 0, update in batch 169000/???, loss: 5.742824554443359
epoch: 0, update in batch 170000/???, loss: 6.129671096801758
epoch: 0, update in batch 171000/???, loss: 5.879034519195557
epoch: 0, update in batch 172000/???, loss: 6.322129249572754
epoch: 0, update in batch 173000/???, loss: 6.805352210998535
epoch: 0, update in batch 174000/???, loss: 7.162431240081787
epoch: 0, update in batch 175000/???, loss: 6.123959541320801
epoch: 0, update in batch 176000/???, loss: 7.544029235839844
epoch: 0, update in batch 177000/???, loss: 5.4254021644592285
epoch: 0, update in batch 178000/???, loss: 5.784268379211426
epoch: 0, update in batch 179000/???, loss: 5.8633856773376465
epoch: 0, update in batch 180000/???, loss: 6.556314945220947
epoch: 0, update in batch 181000/???, loss: 5.215446472167969
epoch: 0, update in batch 182000/???, loss: 6.079234600067139
epoch: 0, update in batch 183000/???, loss: 7.234827995300293
epoch: 0, update in batch 184000/???, loss: 5.249889373779297
epoch: 0, update in batch 185000/???, loss: 5.083311080932617
epoch: 0, update in batch 186000/???, loss: 6.061867713928223
epoch: 0, update in batch 187000/???, loss: 6.060431480407715
epoch: 0, update in batch 188000/???, loss: 5.572680950164795
epoch: 0, update in batch 189000/???, loss: 5.991988182067871
epoch: 0, update in batch 190000/???, loss: 6.521245002746582
epoch: 0, update in batch 191000/???, loss: 5.128615379333496
epoch: 0, update in batch 192000/???, loss: 5.616750717163086
epoch: 0, update in batch 193000/???, loss: 6.1465044021606445
epoch: 0, update in batch 194000/???, loss: 5.93985652923584
epoch: 0, update in batch 195000/???, loss: 6.268892765045166
epoch: 0, update in batch 196000/???, loss: 5.928576469421387
epoch: 0, update in batch 197000/???, loss: 5.257290363311768
epoch: 0, update in batch 198000/???, loss: 6.6432952880859375
epoch: 0, update in batch 199000/???, loss: 6.898074150085449
epoch: 0, update in batch 200000/???, loss: 7.042447566986084
epoch: 0, update in batch 201000/???, loss: 7.104043483734131
epoch: 0, update in batch 202000/???, loss: 6.238812446594238
epoch: 0, update in batch 203000/???, loss: 6.773525238037109
epoch: 0, update in batch 204000/???, loss: 5.054592132568359
epoch: 0, update in batch 205000/???, loss: 6.854428768157959
epoch: 0, update in batch 206000/???, loss: 5.9983601570129395
epoch: 0, update in batch 207000/???, loss: 5.236695766448975
epoch: 0, update in batch 208000/???, loss: 6.086891174316406
epoch: 0, update in batch 209000/???, loss: 6.134495258331299
epoch: 0, update in batch 210000/???, loss: 6.52248477935791
epoch: 0, update in batch 211000/???, loss: 6.028376579284668
epoch: 0, update in batch 212000/???, loss: 6.140281677246094
epoch: 0, update in batch 213000/???, loss: 6.066422462463379
epoch: 0, update in batch 214000/???, loss: 6.868189334869385
epoch: 0, update in batch 215000/???, loss: 6.641358852386475
epoch: 0, update in batch 216000/???, loss: 6.818638801574707
epoch: 0, update in batch 217000/???, loss: 6.40252685546875
epoch: 0, update in batch 218000/???, loss: 5.561617851257324
epoch: 0, update in batch 219000/???, loss: 6.434267997741699
epoch: 0, update in batch 220000/???, loss: 6.33272123336792
epoch: 0, update in batch 221000/???, loss: 5.75616979598999
epoch: 0, update in batch 222000/???, loss: 6.477814674377441
epoch: 0, update in batch 223000/???, loss: 5.259497165679932
epoch: 0, update in batch 224000/???, loss: 5.8639655113220215
epoch: 0, update in batch 225000/???, loss: 6.469706058502197
epoch: 0, update in batch 226000/???, loss: 5.707249164581299
epoch: 0, update in batch 227000/???, loss: 6.394181251525879
epoch: 0, update in batch 228000/???, loss: 5.048886299133301
epoch: 0, update in batch 229000/???, loss: 5.842928409576416
epoch: 0, update in batch 230000/???, loss: 5.627688407897949
epoch: 0, update in batch 231000/???, loss: 7.950299263000488
epoch: 0, update in batch 232000/???, loss: 6.771368503570557
epoch: 0, update in batch 233000/???, loss: 5.787235260009766
epoch: 0, update in batch 234000/???, loss: 5.6070780754089355
epoch: 0, update in batch 235000/???, loss: 6.060035705566406
epoch: 0, update in batch 236000/???, loss: 6.894829750061035
epoch: 0, update in batch 237000/???, loss: 5.672856330871582
epoch: 0, update in batch 238000/???, loss: 5.054213523864746
epoch: 0, update in batch 239000/???, loss: 6.484643459320068
epoch: 0, update in batch 240000/???, loss: 5.800728797912598
epoch: 0, update in batch 241000/???, loss: 5.148013591766357
epoch: 0, update in batch 242000/???, loss: 5.529184818267822
epoch: 0, update in batch 243000/???, loss: 5.959448337554932
epoch: 0, update in batch 244000/???, loss: 6.762448787689209
epoch: 0, update in batch 245000/???, loss: 4.907589912414551
epoch: 0, update in batch 246000/???, loss: 6.275182723999023
epoch: 0, update in batch 247000/???, loss: 5.7234015464782715
epoch: 0, update in batch 248000/???, loss: 6.119207859039307
epoch: 0, update in batch 249000/???, loss: 5.297057151794434
epoch: 0, update in batch 250000/???, loss: 5.924614906311035
epoch: 0, update in batch 251000/???, loss: 6.651083469390869
epoch: 0, update in batch 252000/???, loss: 5.7164201736450195
epoch: 0, update in batch 253000/???, loss: 6.105191230773926
epoch: 0, update in batch 254000/???, loss: 5.791018486022949
epoch: 0, update in batch 255000/???, loss: 6.659502983093262
epoch: 0, update in batch 256000/???, loss: 5.613073348999023
epoch: 0, update in batch 257000/???, loss: 7.501049041748047
epoch: 0, update in batch 258000/???, loss: 6.043797492980957
epoch: 0, update in batch 259000/???, loss: 7.3587327003479
epoch: 0, update in batch 260000/???, loss: 6.276612281799316
epoch: 0, update in batch 261000/???, loss: 6.445192813873291
epoch: 0, update in batch 262000/???, loss: 5.0266547203063965
epoch: 0, update in batch 263000/???, loss: 6.404935359954834
epoch: 0, update in batch 264000/???, loss: 6.5042290687561035
epoch: 0, update in batch 265000/???, loss: 6.880773067474365
epoch: 0, update in batch 266000/???, loss: 6.3690643310546875
epoch: 0, update in batch 267000/???, loss: 6.055562973022461
epoch: 0, update in batch 268000/???, loss: 5.796906471252441
epoch: 0, update in batch 269000/???, loss: 5.654962539672852
epoch: 0, update in batch 270000/???, loss: 6.574362277984619
epoch: 0, update in batch 271000/???, loss: 6.256768226623535
epoch: 0, update in batch 272000/???, loss: 6.8345208168029785
epoch: 0, update in batch 273000/???, loss: 6.066469669342041
epoch: 0, update in batch 274000/???, loss: 6.625809669494629
epoch: 0, update in batch 275000/???, loss: 4.762896537780762
epoch: 0, update in batch 276000/???, loss: 6.019833564758301
epoch: 0, update in batch 277000/???, loss: 6.227939605712891
epoch: 0, update in batch 278000/???, loss: 7.046879768371582
epoch: 0, update in batch 279000/???, loss: 6.068551540374756
epoch: 0, update in batch 280000/???, loss: 6.454771995544434
epoch: 0, update in batch 281000/???, loss: 3.9379985332489014
epoch: 0, update in batch 282000/???, loss: 5.615240097045898
epoch: 0, update in batch 283000/???, loss: 5.7963151931762695
epoch: 0, update in batch 284000/???, loss: 6.064437389373779
epoch: 0, update in batch 285000/???, loss: 6.668734073638916
epoch: 0, update in batch 286000/???, loss: 6.776829719543457
epoch: 0, update in batch 287000/???, loss: 6.170516014099121
epoch: 0, update in batch 288000/???, loss: 4.840399742126465
epoch: 0, update in batch 289000/???, loss: 6.333052635192871
epoch: 0, update in batch 290000/???, loss: 5.595047950744629
epoch: 0, update in batch 291000/???, loss: 6.594934940338135
epoch: 0, update in batch 292000/???, loss: 5.950274467468262
epoch: 0, update in batch 293000/???, loss: 6.123660087585449
epoch: 0, update in batch 294000/???, loss: 5.904355049133301
epoch: 0, update in batch 295000/???, loss: 5.8828630447387695
epoch: 0, update in batch 296000/???, loss: 5.604973316192627
epoch: 0, update in batch 297000/???, loss: 4.842469692230225
epoch: 0, update in batch 298000/???, loss: 5.862446308135986
epoch: 0, update in batch 299000/???, loss: 6.90258264541626
epoch: 0, update in batch 300000/???, loss: 5.941957950592041
epoch: 0, update in batch 301000/???, loss: 5.697750568389893
epoch: 0, update in batch 302000/???, loss: 5.973014831542969
epoch: 0, update in batch 303000/???, loss: 5.46022367477417
epoch: 0, update in batch 304000/???, loss: 6.5218095779418945
epoch: 0, update in batch 305000/???, loss: 6.392545700073242
epoch: 0, update in batch 306000/???, loss: 7.080249786376953
epoch: 0, update in batch 307000/???, loss: 6.355096817016602
epoch: 0, update in batch 308000/???, loss: 5.625491619110107
epoch: 0, update in batch 309000/???, loss: 6.805799961090088
epoch: 0, update in batch 310000/???, loss: 6.426385402679443
epoch: 0, update in batch 311000/???, loss: 5.727842807769775
epoch: 0, update in batch 312000/???, loss: 6.9111199378967285
epoch: 0, update in batch 313000/???, loss: 6.40056848526001
epoch: 0, update in batch 314000/???, loss: 6.145076751708984
epoch: 0, update in batch 315000/???, loss: 6.097104072570801
epoch: 0, update in batch 316000/???, loss: 5.39146089553833
epoch: 0, update in batch 317000/???, loss: 6.125569820404053
epoch: 0, update in batch 318000/???, loss: 6.533677577972412
epoch: 0, update in batch 319000/???, loss: 5.944211483001709
epoch: 0, update in batch 320000/???, loss: 6.542410850524902
epoch: 0, update in batch 321000/???, loss: 5.699315071105957
epoch: 0, update in batch 322000/???, loss: 6.251957893371582
epoch: 0, update in batch 323000/???, loss: 5.346350193023682
epoch: 0, update in batch 324000/???, loss: 5.603858470916748
epoch: 0, update in batch 325000/???, loss: 5.740134239196777
epoch: 0, update in batch 326000/???, loss: 5.575300693511963
epoch: 0, update in batch 327000/???, loss: 6.996762752532959
epoch: 0, update in batch 328000/???, loss: 6.28995418548584
epoch: 0, update in batch 329000/???, loss: 4.519123077392578
epoch: 0, update in batch 330000/???, loss: 5.9068121910095215
epoch: 0, update in batch 331000/???, loss: 6.61830997467041
epoch: 0, update in batch 332000/???, loss: 6.063097953796387
epoch: 0, update in batch 333000/???, loss: 6.419328212738037
epoch: 0, update in batch 334000/???, loss: 5.927584648132324
epoch: 0, update in batch 335000/???, loss: 5.527887344360352
epoch: 0, update in batch 336000/???, loss: 6.114096641540527
epoch: 0, update in batch 337000/???, loss: 5.9415082931518555
epoch: 0, update in batch 338000/???, loss: 5.288441181182861
epoch: 0, update in batch 339000/???, loss: 6.611715793609619
epoch: 0, update in batch 340000/???, loss: 6.770573616027832
def _format_top_k(probs, index_to_key, top_k=30, skip_index=None):
    """Format the ``top_k`` most probable vocabulary entries as a prediction line.

    Args:
        probs: 1-D array of probabilities indexed by vocabulary id.
        index_to_key: sequence mapping vocabulary id -> word.
        top_k: maximum number of words to list.
        skip_index: vocabulary id that must never be listed (e.g. ``<unk>``).

    Returns:
        ``'word:prob word:prob ... :rest'`` where ``rest`` is the probability mass
        not assigned to the listed words, clamped at 0 to absorb float rounding.
    """
    probs = np.asarray(probs)
    candidates = np.arange(len(probs))
    if skip_index is not None:
        candidates = candidates[candidates != skip_index]

    k = min(top_k, len(candidates))
    if k <= 0:
        chosen = candidates[:0]
    elif k < len(candidates):
        # O(V) partial selection instead of rescanning a top-k list for every word.
        chosen = candidates[np.argpartition(probs[candidates], -k)[-k:]]
    else:
        chosen = candidates
    # Most probable first; stable sort keeps vocabulary order among ties.
    chosen = chosen[np.argsort(-probs[chosen], kind='stable')]

    pairs = [f'{index_to_key[i]}:{probs[i]}' for i in chosen]
    rest = max(1.0 - float(probs[chosen].sum()), 0.0)
    return ' '.join(pairs + [f':{rest}'])


def predict_probs(left_tokens, right_tokens, top_k=30):
    """Predict the missing word from its left and right context.

    ``model_front`` scores the word following ``left_tokens``. ``model_back`` scores
    the word following ``right_tokens``; the caller passes the right context reversed,
    so its last token is the one adjacent to the gap. The two softmax distributions
    at the last time step are averaged.

    Args:
        left_tokens: tokens preceding the gap, in reading order.
        right_tokens: tokens following the gap, reversed.
        top_k: number of candidate words to emit (default 30).

    Returns:
        A ``'word:prob ... :rest'`` line. Out-of-vocabulary tokens are encoded as
        ``<unk>``, and ``<unk>`` itself is never emitted as a candidate.
    """
    model_front.eval()
    model_back.eval()

    # Use one mapping for both the membership test and the lookup.
    vocab = train_dataset_front.key_to_index
    unk_index = vocab['<unk>']

    def encode(tokens):
        # A single-sequence batch of vocabulary ids: shape (1, len(tokens)).
        return torch.tensor([[vocab.get(w, unk_index) for w in tokens]]).to(device)

    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        y_pred_left, _ = model_front(encode(left_tokens))
        y_pred_right, _ = model_back(encode(right_tokens))
        probs_left = torch.softmax(y_pred_left[0][-1], dim=0)
        probs_right = torch.softmax(y_pred_right[0][-1], dim=0)
        probs = ((probs_left + probs_right) / 2).cpu().numpy()

    return _format_top_k(probs, index_to_key, top_k=top_k, skip_index=unk_index)
def write_predictions(data, out_path, context_size=5):
    """Write one prediction line per row of ``data`` to ``out_path``.

    Column 6 of each row holds the text left of the gap and column 7 the text
    right of it. The right context is tokenized and then reversed, so its last
    ``context_size`` tokens are the words nearest the gap (read backwards).
    Rows whose left or right side has fewer than ``context_size + 1`` tokens get
    the uninformative prediction ``':1.0'``.
    """
    with open(out_path, 'w', encoding='utf-8') as out_file:
        for _, row in data.iterrows():
            left_words = word_tokenize(clean_text(str(row[6])))
            right_words = word_tokenize(clean_text(str(row[7])))
            right_words.reverse()
            # NOTE(review): this requires one more token than is fed to the model
            # (original `< 6` check with a 5-token window); kept as-is — confirm intent.
            too_short = (len(left_words) < context_size + 1
                         or len(right_words) < context_size + 1)
            if too_short:
                prediction = ':1.0'
            else:
                prediction = predict_probs(left_words[-context_size:],
                                           right_words[-context_size:])
            out_file.write(prediction + '\n')


dev_data = pd.read_csv('dev-0/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
test_data = pd.read_csv('test-A/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
write_predictions(dev_data, 'dev-0/out.tsv')
write_predictions(test_data, 'test-A/out.tsv')