challenging-america-word-ga.../run.ipynb
2022-05-28 15:36:41 +02:00

42 KiB

import pandas as pd
import numpy as np
import regex as re
import csv
import torch
from torch import nn
from gensim.models import Word2Vec
from nltk.tokenize import word_tokenize
# Release unused memory held by PyTorch's CUDA caching allocator.
torch.cuda.empty_cache()

# Prefer the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def clean_text(text):
    """Normalise a raw text fragment before tokenisation.

    The text is lower-cased, the two line-break markers are handled (the
    hyphen-prefixed marker is deleted, the plain marker becomes a space),
    common English contractions are expanded and finally every Unicode
    punctuation character is removed.

    Args:
        text: Raw string to clean.

    Returns:
        The cleaned, lower-cased string without punctuation.
    """
    # NOTE(review): the markers below contain a long run of literal
    # backslashes before the `n`; confirm this matches how line breaks are
    # actually escaped in the corpus.
    text = text.lower().replace('-\\\\\\\\\\\\\\\\n', '').replace('\\\\\\\\\\\\\\\\n', ' ')

    # Contractions must be expanded *before* punctuation is stripped: the
    # apostrophe is itself punctuation, so afterwards nothing could match.
    # "n't" is handled first so that "don't" becomes "do not" rather than
    # "don not"; the bare "'t" rule is kept as a fallback.
    for contraction, expansion in (
        ("n't", " not"),
        ("'t", " not"),
        ("'s", " is"),
        ("'ll", " will"),
        ("'m", " am"),
        ("'ve", " have"),
    ):
        text = text.replace(contraction, expansion)

    # `re` is the third-party `regex` module here, which supports \p{P}
    # (any Unicode punctuation character).
    text = re.sub(r'\p{P}', '', text)

    return text
# Training inputs and their expected words, read as raw tab-separated text
# (no header row, no quote handling).
# NOTE(review): on_bad_lines='skip' drops malformed rows independently in each
# file; if a row is skipped in one file but not the other, the positional
# concat below pairs inputs with the wrong labels -- confirm nothing is skipped.
train_data = pd.read_csv(
    'train/in.tsv.xz',
    sep='\t',
    header=None,
    quoting=csv.QUOTE_NONE,
    on_bad_lines='skip',
)
train_labels = pd.read_csv(
    'train/expected.tsv',
    sep='\t',
    header=None,
    quoting=csv.QUOTE_NONE,
    on_bad_lines='skip',
)

# Keep only input columns 6 and 7 and put the label column (named 0) next to
# them, giving a frame with columns 6, 7 and 0.
train_data = pd.concat([train_data[[6, 7]], train_labels], axis=1)
class TrainCorpus:
    """Re-iterable corpus yielding one tokenised training text per row.

    For every row the text is rebuilt from column 6, column 0 and column 7
    (in that order), cleaned with ``clean_text`` and split with
    ``word_tokenize``. Because ``__iter__`` is a generator method, the corpus
    can be traversed more than once.
    """

    def __init__(self, data):
        """Store the frame to iterate over.

        Args:
            data: DataFrame whose rows provide columns ``6``, ``0`` and ``7``.
        """
        self.data = data

    def __iter__(self):
        """Yield the list of tokens for each row of ``self.data``."""
        for _, row in self.data.iterrows():
            # Join with spaces: plain concatenation fused the last word of
            # column 6 with the column-0 word, and that word with the first
            # word of column 7, producing tokens that never occur in text.
            text = ' '.join((str(row[6]), str(row[0]), str(row[7])))
            yield word_tokenize(clean_text(text))
# Build the vocabulary from the first 100000 training rows. Within this file
# the Word2Vec object is only used for its vocabulary mappings; words seen
# fewer than 10 times are left out.
train_sentences = TrainCorpus(train_data.head(100000))
w2v_model = Word2Vec(vector_size=100, min_count=10)
w2v_model.build_vocab(corpus_iterable=train_sentences)

index_to_key = w2v_model.wv.index_to_key
key_to_index = w2v_model.wv.key_to_index

# Register an out-of-vocabulary token as the very last vocabulary entry.
key_to_index['<unk>'] = len(index_to_key)
index_to_key.append('<unk>')

vocab_size = len(index_to_key)
print(vocab_size)
97122
class TrainDataset(torch.utils.data.IterableDataset):
    """Streams (input, target) windows of word indices for next-word training.

    Each row's text is rebuilt from columns 6, 0 and 7, cleaned, tokenised and
    mapped to vocabulary indices (unknown words map to ``'<unk>'``). A window
    of five consecutive tokens is the input; the target is the same window
    shifted one token to the right.
    """

    def __init__(self, data, index_to_key, key_to_index):
        """Store the data frame and the vocabulary mappings.

        Args:
            data: DataFrame whose rows provide columns ``6``, ``0`` and ``7``.
            index_to_key: List mapping a vocabulary index to its word.
            key_to_index: Dict mapping a word to its vocabulary index; must
                contain the ``'<unk>'`` key.
        """
        self.data = data
        self.index_to_key = index_to_key
        self.key_to_index = key_to_index
        self.vocab_size = len(key_to_index)

    def __iter__(self):
        """Yield ``(input, target)`` pairs of int64 arrays of length 5."""
        unk_index = self.key_to_index['<unk>']
        for _, row in self.data.iterrows():
            # Join with spaces so the words at the column boundaries are not
            # fused into a single bogus token.
            text = ' '.join((str(row[6]), str(row[0]), str(row[7])))
            tokens = word_tokenize(clean_text(text))

            # Map every token once per row instead of twice per window.
            indices = [self.key_to_index.get(token, unk_index) for token in tokens]

            for i in range(5, len(indices)):
                input_window = np.asarray(indices[i - 5:i], dtype=np.int64)
                target_window = np.asarray(indices[i - 4:i + 1], dtype=np.int64)
                yield input_window, target_window
class Model(nn.Module):
    """Word-level language model: embedding -> 2-layer GRU -> linear layer.

    Inputs are integer word indices shaped ``(batch, seq)``; the model returns
    unnormalised scores over the vocabulary for every position.
    """

    def __init__(self, embed_size, vocab_size):
        """Create the layers.

        Args:
            embed_size: Dimension of the word embeddings.
            vocab_size: Number of words in the vocabulary (also the size of
                the output layer).
        """
        super().__init__()
        self.embed_size = embed_size
        self.vocab_size = vocab_size
        self.gru_size = 128
        self.num_layers = 2

        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=self.embed_size)
        # batch_first=True: inputs arrive as (batch, seq). Without it the GRU
        # treats the batch axis as time and unrolls across unrelated samples.
        self.gru = nn.GRU(
            input_size=self.embed_size,
            hidden_size=self.gru_size,
            num_layers=self.num_layers,
            dropout=0.2,
            batch_first=True,
        )
        self.fc = nn.Linear(self.gru_size, vocab_size)

    def forward(self, x, prev_state=None):
        """Run the model on a batch of index sequences.

        Args:
            x: Integer tensor of word indices shaped ``(batch, seq)``.
            prev_state: Optional initial GRU hidden state shaped
                ``(num_layers, batch, gru_size)``; zeros are used when omitted.

        Returns:
            A ``(logits, state)`` tuple: ``logits`` is shaped
            ``(batch, seq, vocab_size)`` and ``state`` is the final hidden
            state shaped ``(num_layers, batch, gru_size)``.
        """
        embed = self.embed(x)
        output, state = self.gru(embed, prev_state)
        # Only logits are returned; callers apply softmax / cross-entropy.
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return an all-zero initial hidden state usable as ``prev_state``.

        Args:
            sequence_length: Size of the second dimension of the state. For
                the GRU this is the batch size; the parameter name is kept for
                compatibility with existing callers.

        Returns:
            A zero tensor shaped ``(num_layers, sequence_length, gru_size)``
            on the same device as the model parameters. A GRU takes a single
            hidden-state tensor, not an LSTM-style ``(h, c)`` tuple.
        """
        return torch.zeros(
            self.num_layers,
            sequence_length,
            self.gru_size,
            device=self.fc.weight.device,
        )
from torch.utils.data import DataLoader
from torch.optim import Adam

def train(dataset, model, max_epochs, batch_size):
    """Train ``model`` on ``dataset`` with Adam and cross-entropy loss.

    Args:
        dataset: Dataset yielding ``(input, target)`` index arrays.
        model: Module whose call returns a ``(logits, state)`` pair.
        max_epochs: Number of passes over the dataset.
        batch_size: Number of samples per optimisation step.

    The current loss is printed every 1000 batches.
    """
    model.train()

    loader = DataLoader(dataset, batch_size=batch_size)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.001)

    for epoch in range(max_epochs):
        for batch, (inputs, targets) in enumerate(loader):
            inputs, targets = inputs.to(device), targets.to(device)

            optimizer.zero_grad()
            logits, _ = model(inputs)

            # Swap dims 1 and 2 so the class scores sit on dim 1, which is
            # where CrossEntropyLoss expects them for sequence targets.
            loss = loss_fn(logits.transpose(1, 2), targets)
            loss.backward()
            optimizer.step()

            if batch % 1000 != 0:
                continue
            print(f'epoch: {epoch}, update in batch {batch}/???, loss: {loss.item()}')
# Train for one epoch on the first 100000 rows, in batches of 64 windows.
train_dataset = TrainDataset(
    data=train_data.head(100000),
    index_to_key=index_to_key,
    key_to_index=key_to_index,
)
model = Model(embed_size=100, vocab_size=vocab_size)
model = model.to(device)
train(dataset=train_dataset, model=model, max_epochs=1, batch_size=64)
epoch: 0, update in batch 0/???, loss: 11.47425365447998
epoch: 0, update in batch 1000/???, loss: 7.042205810546875
epoch: 0, update in batch 2000/???, loss: 7.235440731048584
epoch: 0, update in batch 3000/???, loss: 7.251269340515137
epoch: 0, update in batch 4000/???, loss: 6.944191932678223
epoch: 0, update in batch 5000/???, loss: 6.263372421264648
epoch: 0, update in batch 6000/???, loss: 6.181947231292725
epoch: 0, update in batch 7000/???, loss: 6.508013725280762
epoch: 0, update in batch 8000/???, loss: 6.658236026763916
epoch: 0, update in batch 9000/???, loss: 6.536279201507568
epoch: 0, update in batch 10000/???, loss: 6.802626609802246
epoch: 0, update in batch 11000/???, loss: 6.961029052734375
epoch: 0, update in batch 12000/???, loss: 7.713824272155762
epoch: 0, update in batch 13000/???, loss: 8.100411415100098
epoch: 0, update in batch 14000/???, loss: 6.457145690917969
epoch: 0, update in batch 15000/???, loss: 6.850286960601807
epoch: 0, update in batch 16000/???, loss: 6.794063568115234
epoch: 0, update in batch 17000/???, loss: 6.311314582824707
epoch: 0, update in batch 18000/???, loss: 6.611917018890381
epoch: 0, update in batch 19000/???, loss: 5.679810523986816
epoch: 0, update in batch 20000/???, loss: 7.110655307769775
epoch: 0, update in batch 21000/???, loss: 7.170722961425781
epoch: 0, update in batch 22000/???, loss: 6.584908485412598
epoch: 0, update in batch 23000/???, loss: 7.224095344543457
epoch: 0, update in batch 24000/???, loss: 5.827445983886719
epoch: 0, update in batch 25000/???, loss: 6.444586753845215
epoch: 0, update in batch 26000/???, loss: 6.149054527282715
epoch: 0, update in batch 27000/???, loss: 6.259871482849121
epoch: 0, update in batch 28000/???, loss: 5.789839744567871
epoch: 0, update in batch 29000/???, loss: 7.025563716888428
epoch: 0, update in batch 30000/???, loss: 7.265492916107178
epoch: 0, update in batch 31000/???, loss: 4.921586036682129
epoch: 0, update in batch 32000/???, loss: 6.467754364013672
epoch: 0, update in batch 33000/???, loss: 7.393715858459473
epoch: 0, update in batch 34000/???, loss: 6.9696760177612305
epoch: 0, update in batch 35000/???, loss: 7.276318550109863
epoch: 0, update in batch 36000/???, loss: 7.011231899261475
epoch: 0, update in batch 37000/???, loss: 7.029260158538818
epoch: 0, update in batch 38000/???, loss: 6.723126411437988
epoch: 0, update in batch 39000/???, loss: 6.828773498535156
epoch: 0, update in batch 40000/???, loss: 6.069770336151123
epoch: 0, update in batch 41000/???, loss: 6.651298522949219
epoch: 0, update in batch 42000/???, loss: 7.455380916595459
epoch: 0, update in batch 43000/???, loss: 5.594773769378662
epoch: 0, update in batch 44000/???, loss: 6.102865219116211
epoch: 0, update in batch 45000/???, loss: 6.04202127456665
epoch: 0, update in batch 46000/???, loss: 6.472177982330322
epoch: 0, update in batch 47000/???, loss: 5.870923042297363
epoch: 0, update in batch 48000/???, loss: 6.286317348480225
epoch: 0, update in batch 49000/???, loss: 7.157052516937256
epoch: 0, update in batch 50000/???, loss: 5.888463020324707
epoch: 0, update in batch 51000/???, loss: 5.609915733337402
epoch: 0, update in batch 52000/???, loss: 6.565190315246582
epoch: 0, update in batch 53000/???, loss: 6.4924468994140625
epoch: 0, update in batch 54000/???, loss: 6.856420040130615
epoch: 0, update in batch 55000/???, loss: 7.389428615570068
epoch: 0, update in batch 56000/???, loss: 5.927685260772705
epoch: 0, update in batch 57000/???, loss: 7.4227423667907715
epoch: 0, update in batch 58000/???, loss: 6.46466064453125
epoch: 0, update in batch 59000/???, loss: 6.586294651031494
epoch: 0, update in batch 60000/???, loss: 5.797083854675293
epoch: 0, update in batch 61000/???, loss: 4.825878143310547
epoch: 0, update in batch 62000/???, loss: 6.911933898925781
epoch: 0, update in batch 63000/???, loss: 7.684759616851807
epoch: 0, update in batch 64000/???, loss: 5.716580390930176
epoch: 0, update in batch 65000/???, loss: 6.1738667488098145
epoch: 0, update in batch 66000/???, loss: 6.219714164733887
epoch: 0, update in batch 67000/???, loss: 5.4024128913879395
epoch: 0, update in batch 68000/???, loss: 6.912312984466553
epoch: 0, update in batch 69000/???, loss: 6.703289031982422
epoch: 0, update in batch 70000/???, loss: 7.375630855560303
epoch: 0, update in batch 71000/???, loss: 5.757082462310791
epoch: 0, update in batch 72000/???, loss: 5.992405414581299
epoch: 0, update in batch 73000/???, loss: 6.706838130950928
epoch: 0, update in batch 74000/???, loss: 7.376870155334473
epoch: 0, update in batch 75000/???, loss: 6.676860809326172
epoch: 0, update in batch 76000/???, loss: 5.904101848602295
epoch: 0, update in batch 77000/???, loss: 6.776932716369629
epoch: 0, update in batch 78000/???, loss: 5.682181358337402
epoch: 0, update in batch 79000/???, loss: 6.211178302764893
epoch: 0, update in batch 80000/???, loss: 6.366950035095215
epoch: 0, update in batch 81000/???, loss: 5.25206184387207
epoch: 0, update in batch 82000/???, loss: 6.30997371673584
epoch: 0, update in batch 83000/???, loss: 6.351908206939697
epoch: 0, update in batch 84000/???, loss: 7.659114837646484
epoch: 0, update in batch 85000/???, loss: 6.5041704177856445
epoch: 0, update in batch 86000/???, loss: 6.770291328430176
epoch: 0, update in batch 87000/???, loss: 6.530011177062988
epoch: 0, update in batch 88000/???, loss: 6.317249298095703
epoch: 0, update in batch 89000/???, loss: 6.191559314727783
epoch: 0, update in batch 90000/???, loss: 5.79150390625
epoch: 0, update in batch 91000/???, loss: 6.356796741485596
epoch: 0, update in batch 92000/???, loss: 7.3577141761779785
epoch: 0, update in batch 93000/???, loss: 6.529308319091797
epoch: 0, update in batch 94000/???, loss: 7.740485191345215
epoch: 0, update in batch 95000/???, loss: 6.348109245300293
epoch: 0, update in batch 96000/???, loss: 6.032902717590332
epoch: 0, update in batch 97000/???, loss: 4.505112648010254
epoch: 0, update in batch 98000/???, loss: 6.946290493011475
epoch: 0, update in batch 99000/???, loss: 6.237973213195801
epoch: 0, update in batch 100000/???, loss: 6.963421821594238
epoch: 0, update in batch 101000/???, loss: 5.309017181396484
epoch: 0, update in batch 102000/???, loss: 6.242384910583496
epoch: 0, update in batch 103000/???, loss: 6.8203558921813965
epoch: 0, update in batch 104000/???, loss: 6.242025852203369
epoch: 0, update in batch 105000/???, loss: 6.765100002288818
epoch: 0, update in batch 106000/???, loss: 6.8838043212890625
epoch: 0, update in batch 107000/???, loss: 6.856662750244141
epoch: 0, update in batch 108000/???, loss: 6.379549503326416
epoch: 0, update in batch 109000/???, loss: 6.797707557678223
epoch: 0, update in batch 110000/???, loss: 7.2699689865112305
epoch: 0, update in batch 111000/???, loss: 7.040057182312012
epoch: 0, update in batch 112000/???, loss: 6.7861223220825195
epoch: 0, update in batch 113000/???, loss: 6.064489364624023
epoch: 0, update in batch 114000/???, loss: 6.095967769622803
epoch: 0, update in batch 115000/???, loss: 5.757347106933594
epoch: 0, update in batch 116000/???, loss: 6.529908657073975
epoch: 0, update in batch 117000/???, loss: 6.030801296234131
epoch: 0, update in batch 118000/???, loss: 6.179767608642578
epoch: 0, update in batch 119000/???, loss: 5.436234474182129
epoch: 0, update in batch 120000/???, loss: 7.342876434326172
epoch: 0, update in batch 121000/???, loss: 6.862719535827637
epoch: 0, update in batch 122000/???, loss: 6.491606712341309
epoch: 0, update in batch 123000/???, loss: 7.195406436920166
epoch: 0, update in batch 124000/???, loss: 5.481313228607178
epoch: 0, update in batch 125000/???, loss: 7.963885307312012
epoch: 0, update in batch 126000/???, loss: 6.479039669036865
epoch: 0, update in batch 127000/???, loss: 7.037934303283691
epoch: 0, update in batch 128000/???, loss: 5.903053283691406
epoch: 0, update in batch 129000/???, loss: 6.815878391265869
epoch: 0, update in batch 130000/???, loss: 6.497969150543213
epoch: 0, update in batch 131000/???, loss: 5.623625755310059
epoch: 0, update in batch 132000/???, loss: 7.118441104888916
epoch: 0, update in batch 133000/???, loss: 5.964345455169678
epoch: 0, update in batch 134000/???, loss: 6.112139701843262
epoch: 0, update in batch 135000/???, loss: 6.5865373611450195
epoch: 0, update in batch 136000/???, loss: 7.498536109924316
epoch: 0, update in batch 137000/???, loss: 7.124758243560791
epoch: 0, update in batch 138000/???, loss: 6.871796607971191
epoch: 0, update in batch 139000/???, loss: 5.8565263748168945
epoch: 0, update in batch 140000/???, loss: 5.723143577575684
epoch: 0, update in batch 141000/???, loss: 5.601426124572754
epoch: 0, update in batch 142000/???, loss: 5.495566368103027
epoch: 0, update in batch 143000/???, loss: 6.936192989349365
epoch: 0, update in batch 144000/???, loss: 6.1843671798706055
epoch: 0, update in batch 145000/???, loss: 6.886034965515137
epoch: 0, update in batch 146000/???, loss: 6.655320167541504
epoch: 0, update in batch 147000/???, loss: 6.46828556060791
epoch: 0, update in batch 148000/???, loss: 5.607057571411133
epoch: 0, update in batch 149000/???, loss: 7.182212829589844
epoch: 0, update in batch 150000/???, loss: 7.241323947906494
epoch: 0, update in batch 151000/???, loss: 7.308540344238281
epoch: 0, update in batch 152000/???, loss: 5.267911434173584
epoch: 0, update in batch 153000/???, loss: 5.895949363708496
epoch: 0, update in batch 154000/???, loss: 6.629178524017334
epoch: 0, update in batch 155000/???, loss: 4.9156012535095215
epoch: 0, update in batch 156000/???, loss: 7.181819915771484
epoch: 0, update in batch 157000/???, loss: 7.438391208648682
epoch: 0, update in batch 158000/???, loss: 6.406006813049316
epoch: 0, update in batch 159000/???, loss: 6.486207008361816
epoch: 0, update in batch 160000/???, loss: 7.041951656341553
epoch: 0, update in batch 161000/???, loss: 5.310082912445068
epoch: 0, update in batch 162000/???, loss: 6.9074387550354
epoch: 0, update in batch 163000/???, loss: 6.644919395446777
epoch: 0, update in batch 164000/???, loss: 6.011733055114746
epoch: 0, update in batch 165000/???, loss: 6.494180202484131
epoch: 0, update in batch 166000/???, loss: 5.390150547027588
epoch: 0, update in batch 167000/???, loss: 6.627297401428223
epoch: 0, update in batch 168000/???, loss: 6.9020209312438965
epoch: 0, update in batch 169000/???, loss: 7.317750453948975
epoch: 0, update in batch 170000/???, loss: 5.69993782043457
epoch: 0, update in batch 171000/???, loss: 6.658817291259766
epoch: 0, update in batch 172000/???, loss: 6.422945976257324
epoch: 0, update in batch 173000/???, loss: 5.822269439697266
epoch: 0, update in batch 174000/???, loss: 6.513391017913818
epoch: 0, update in batch 175000/???, loss: 5.886590957641602
epoch: 0, update in batch 176000/???, loss: 7.119387149810791
epoch: 0, update in batch 177000/???, loss: 6.933981418609619
epoch: 0, update in batch 178000/???, loss: 6.678143501281738
epoch: 0, update in batch 179000/???, loss: 6.890423774719238
epoch: 0, update in batch 180000/???, loss: 6.932961940765381
epoch: 0, update in batch 181000/???, loss: 6.650975704193115
epoch: 0, update in batch 182000/???, loss: 6.732748985290527
epoch: 0, update in batch 183000/???, loss: 6.064764976501465
epoch: 0, update in batch 184000/???, loss: 5.282295227050781
epoch: 0, update in batch 185000/???, loss: 6.569302558898926
epoch: 0, update in batch 186000/???, loss: 5.800485610961914
epoch: 0, update in batch 187000/???, loss: 6.175991058349609
epoch: 0, update in batch 188000/???, loss: 5.405575752258301
epoch: 0, update in batch 189000/???, loss: 6.191354751586914
epoch: 0, update in batch 190000/???, loss: 6.156663417816162
epoch: 0, update in batch 191000/???, loss: 6.937534332275391
epoch: 0, update in batch 192000/???, loss: 6.562686920166016
epoch: 0, update in batch 193000/???, loss: 6.639985084533691
epoch: 0, update in batch 194000/???, loss: 7.285438537597656
epoch: 0, update in batch 195000/???, loss: 6.528258323669434
epoch: 0, update in batch 196000/???, loss: 8.326434135437012
epoch: 0, update in batch 197000/???, loss: 6.781360626220703
epoch: 0, update in batch 198000/???, loss: 7.223299980163574
epoch: 0, update in batch 199000/???, loss: 6.411007881164551
epoch: 0, update in batch 200000/???, loss: 5.885635852813721
epoch: 0, update in batch 201000/???, loss: 5.706809043884277
epoch: 0, update in batch 202000/???, loss: 6.230217933654785
epoch: 0, update in batch 203000/???, loss: 7.056562900543213
epoch: 0, update in batch 204000/???, loss: 7.2273077964782715
epoch: 0, update in batch 205000/???, loss: 6.342462539672852
epoch: 0, update in batch 206000/???, loss: 6.556817054748535
epoch: 0, update in batch 207000/???, loss: 5.882349967956543
epoch: 0, update in batch 208000/???, loss: 6.755805015563965
epoch: 0, update in batch 209000/???, loss: 6.5045623779296875
epoch: 0, update in batch 210000/???, loss: 6.525590419769287
epoch: 0, update in batch 211000/???, loss: 6.49679708480835
epoch: 0, update in batch 212000/???, loss: 6.562323093414307
epoch: 0, update in batch 213000/???, loss: 5.227139472961426
epoch: 0, update in batch 214000/???, loss: 7.044825077056885
epoch: 0, update in batch 215000/???, loss: 6.002442359924316
epoch: 0, update in batch 216000/???, loss: 6.084803581237793
epoch: 0, update in batch 217000/???, loss: 7.425839900970459
epoch: 0, update in batch 218000/???, loss: 6.818853855133057
epoch: 0, update in batch 219000/???, loss: 7.0153374671936035
epoch: 0, update in batch 220000/???, loss: 6.219962120056152
epoch: 0, update in batch 221000/???, loss: 5.9975385665893555
epoch: 0, update in batch 222000/???, loss: 6.480047702789307
epoch: 0, update in batch 223000/???, loss: 6.405727386474609
epoch: 0, update in batch 224000/???, loss: 4.7763471603393555
epoch: 0, update in batch 225000/???, loss: 6.615710258483887
epoch: 0, update in batch 226000/???, loss: 6.385044574737549
epoch: 0, update in batch 227000/???, loss: 7.260453701019287
epoch: 0, update in batch 228000/???, loss: 6.9794135093688965
epoch: 0, update in batch 229000/???, loss: 6.235735893249512
epoch: 0, update in batch 230000/???, loss: 6.478426456451416
epoch: 0, update in batch 231000/???, loss: 6.181302547454834
epoch: 0, update in batch 232000/???, loss: 5.826043128967285
epoch: 0, update in batch 233000/???, loss: 5.9517011642456055
epoch: 0, update in batch 234000/???, loss: 8.0064697265625
epoch: 0, update in batch 235000/???, loss: 6.7822675704956055
epoch: 0, update in batch 236000/???, loss: 6.293349742889404
epoch: 0, update in batch 237000/???, loss: 6.442999362945557
epoch: 0, update in batch 238000/???, loss: 6.282561302185059
epoch: 0, update in batch 239000/???, loss: 7.166723728179932
epoch: 0, update in batch 240000/???, loss: 7.189905643463135
epoch: 0, update in batch 241000/???, loss: 8.462562561035156
epoch: 0, update in batch 242000/???, loss: 7.446291923522949
epoch: 0, update in batch 243000/???, loss: 6.382981777191162
epoch: 0, update in batch 244000/???, loss: 7.635994911193848
epoch: 0, update in batch 245000/???, loss: 6.635537147521973
epoch: 0, update in batch 246000/???, loss: 6.068560600280762
epoch: 0, update in batch 247000/???, loss: 6.193384170532227
epoch: 0, update in batch 248000/???, loss: 5.702363967895508
epoch: 0, update in batch 249000/???, loss: 6.09995174407959
epoch: 0, update in batch 250000/???, loss: 6.312221050262451
epoch: 0, update in batch 251000/???, loss: 5.853858470916748
epoch: 0, update in batch 252000/???, loss: 5.886989593505859
epoch: 0, update in batch 253000/???, loss: 5.801788330078125
epoch: 0, update in batch 254000/???, loss: 6.032407760620117
epoch: 0, update in batch 255000/???, loss: 7.480917453765869
epoch: 0, update in batch 256000/???, loss: 6.578718662261963
epoch: 0, update in batch 257000/???, loss: 6.344462871551514
epoch: 0, update in batch 258000/???, loss: 5.939858436584473
epoch: 0, update in batch 259000/???, loss: 5.181772232055664
epoch: 0, update in batch 260000/???, loss: 6.640598297119141
epoch: 0, update in batch 261000/???, loss: 7.189258575439453
epoch: 0, update in batch 262000/???, loss: 6.2269287109375
epoch: 0, update in batch 263000/???, loss: 5.8858795166015625
epoch: 0, update in batch 264000/???, loss: 6.333988666534424
epoch: 0, update in batch 265000/???, loss: 6.313681602478027
epoch: 0, update in batch 266000/???, loss: 5.485809803009033
epoch: 0, update in batch 267000/???, loss: 6.250800609588623
epoch: 0, update in batch 268000/???, loss: 6.676806449890137
epoch: 0, update in batch 269000/???, loss: 5.6487932205200195
epoch: 0, update in batch 270000/???, loss: 6.648938179016113
epoch: 0, update in batch 271000/???, loss: 6.26931095123291
epoch: 0, update in batch 272000/???, loss: 5.343636512756348
epoch: 0, update in batch 273000/???, loss: 7.051453590393066
epoch: 0, update in batch 274000/???, loss: 4.578436851501465
epoch: 0, update in batch 275000/???, loss: 5.400996685028076
epoch: 0, update in batch 276000/???, loss: 6.129047870635986
epoch: 0, update in batch 277000/???, loss: 7.549851894378662
epoch: 0, update in batch 278000/???, loss: 6.093559265136719
epoch: 0, update in batch 279000/???, loss: 5.6921467781066895
epoch: 0, update in batch 280000/???, loss: 5.789463996887207
epoch: 0, update in batch 281000/???, loss: 5.681942939758301
epoch: 0, update in batch 282000/???, loss: 6.750497341156006
epoch: 0, update in batch 283000/???, loss: 5.960292339324951
epoch: 0, update in batch 284000/???, loss: 6.160388469696045
epoch: 0, update in batch 285000/???, loss: 7.137685298919678
epoch: 0, update in batch 286000/???, loss: 7.7431464195251465
epoch: 0, update in batch 287000/???, loss: 5.229738712310791
epoch: 0, update in batch 288000/???, loss: 6.654232025146484
epoch: 0, update in batch 289000/???, loss: 6.229329586029053
epoch: 0, update in batch 290000/???, loss: 7.188180446624756
epoch: 0, update in batch 291000/???, loss: 6.244111061096191
epoch: 0, update in batch 292000/???, loss: 7.199154853820801
epoch: 0, update in batch 293000/???, loss: 7.1866865158081055
epoch: 0, update in batch 294000/???, loss: 6.574115753173828
epoch: 0, update in batch 295000/???, loss: 6.487138271331787
epoch: 0, update in batch 296000/???, loss: 5.813161849975586
epoch: 0, update in batch 297000/???, loss: 6.159414291381836
epoch: 0, update in batch 298000/???, loss: 7.256616115570068
epoch: 0, update in batch 299000/???, loss: 7.511231899261475
epoch: 0, update in batch 300000/???, loss: 6.148821830749512
epoch: 0, update in batch 301000/???, loss: 7.108969211578369
epoch: 0, update in batch 302000/???, loss: 6.528176307678223
epoch: 0, update in batch 303000/???, loss: 6.276839256286621
epoch: 0, update in batch 304000/???, loss: 6.484020233154297
epoch: 0, update in batch 305000/???, loss: 6.38557767868042
epoch: 0, update in batch 306000/???, loss: 7.068814754486084
epoch: 0, update in batch 307000/???, loss: 5.844017505645752
epoch: 0, update in batch 308000/???, loss: 4.25785493850708
epoch: 0, update in batch 309000/???, loss: 6.709985256195068
epoch: 0, update in batch 310000/???, loss: 6.543104648590088
epoch: 0, update in batch 311000/???, loss: 6.675828456878662
epoch: 0, update in batch 312000/???, loss: 5.82969856262207
epoch: 0, update in batch 313000/???, loss: 6.05246639251709
epoch: 0, update in batch 314000/???, loss: 7.2366156578063965
epoch: 0, update in batch 315000/???, loss: 5.039820194244385
epoch: 0, update in batch 316000/???, loss: 5.943173885345459
epoch: 0, update in batch 317000/???, loss: 6.2509002685546875
epoch: 0, update in batch 318000/???, loss: 6.451228141784668
epoch: 0, update in batch 319000/???, loss: 6.6381049156188965
epoch: 0, update in batch 320000/???, loss: 6.570329189300537
epoch: 0, update in batch 321000/???, loss: 5.376622200012207
epoch: 0, update in batch 322000/???, loss: 6.487462520599365
epoch: 0, update in batch 323000/???, loss: 6.676497459411621
epoch: 0, update in batch 324000/???, loss: 6.283420562744141
epoch: 0, update in batch 325000/???, loss: 6.164648532867432
epoch: 0, update in batch 326000/???, loss: 6.839153289794922
epoch: 0, update in batch 327000/???, loss: 6.435141086578369
epoch: 0, update in batch 328000/???, loss: 6.160590171813965
epoch: 0, update in batch 329000/???, loss: 5.876160621643066
epoch: 0, update in batch 330000/???, loss: 6.47445011138916
epoch: 0, update in batch 331000/???, loss: 6.294231414794922
epoch: 0, update in batch 332000/???, loss: 6.099027156829834
epoch: 0, update in batch 333000/???, loss: 6.986542701721191
epoch: 0, update in batch 334000/???, loss: 7.018263816833496
epoch: 0, update in batch 335000/???, loss: 6.906959533691406
epoch: 0, update in batch 336000/???, loss: 6.12356424331665
epoch: 0, update in batch 337000/???, loss: 6.316069602966309
epoch: 0, update in batch 338000/???, loss: 6.908566474914551
epoch: 0, update in batch 339000/???, loss: 5.628839492797852
epoch: 0, update in batch 340000/???, loss: 7.069979667663574
epoch: 0, update in batch 341000/???, loss: 5.350735187530518
epoch: 0, update in batch 342000/???, loss: 5.377245903015137
epoch: 0, update in batch 343000/???, loss: 5.2340989112854
epoch: 0, update in batch 344000/???, loss: 6.087491512298584
epoch: 0, update in batch 345000/???, loss: 6.162985801696777
epoch: 0, update in batch 346000/???, loss: 5.655491828918457
epoch: 0, update in batch 347000/???, loss: 5.311842918395996
epoch: 0, update in batch 348000/???, loss: 7.577170372009277
epoch: 0, update in batch 349000/???, loss: 6.730460166931152
epoch: 0, update in batch 350000/???, loss: 6.782231330871582
epoch: 0, update in batch 351000/???, loss: 6.789486885070801
epoch: 0, update in batch 352000/???, loss: 5.473587989807129
epoch: 0, update in batch 353000/???, loss: 5.531443119049072
epoch: 0, update in batch 354000/???, loss: 7.220989227294922
epoch: 0, update in batch 355000/???, loss: 5.954288005828857
epoch: 0, update in batch 356000/???, loss: 4.112783432006836
epoch: 0, update in batch 357000/???, loss: 5.409672737121582
epoch: 0, update in batch 358000/???, loss: 6.408724784851074
epoch: 0, update in batch 359000/???, loss: 6.744941711425781
epoch: 0, update in batch 360000/???, loss: 6.218225479125977
epoch: 0, update in batch 361000/???, loss: 6.071394920349121
epoch: 0, update in batch 362000/???, loss: 6.137121677398682
epoch: 0, update in batch 363000/???, loss: 5.876864433288574
epoch: 0, update in batch 364000/???, loss: 7.715426445007324
epoch: 0, update in batch 365000/???, loss: 6.217362880706787
epoch: 0, update in batch 366000/???, loss: 6.741396903991699
epoch: 0, update in batch 367000/???, loss: 6.4564313888549805
epoch: 0, update in batch 368000/???, loss: 6.994439601898193
epoch: 0, update in batch 369000/???, loss: 6.061278820037842
epoch: 0, update in batch 370000/???, loss: 4.894576549530029
epoch: 0, update in batch 371000/???, loss: 6.351264953613281
epoch: 0, update in batch 372000/???, loss: 6.826904296875
epoch: 0, update in batch 373000/???, loss: 6.090312480926514
epoch: 0, update in batch 374000/???, loss: 5.797528266906738
epoch: 0, update in batch 375000/???, loss: 7.3235673904418945
epoch: 0, update in batch 376000/???, loss: 5.5752973556518555
epoch: 0, update in batch 377000/???, loss: 6.29438591003418
epoch: 0, update in batch 378000/???, loss: 5.238917827606201
epoch: 0, update in batch 379000/???, loss: 5.542972564697266
epoch: 0, update in batch 380000/???, loss: 6.5024614334106445
epoch: 0, update in batch 381000/???, loss: 6.918997287750244
epoch: 0, update in batch 382000/???, loss: 5.694029331207275
epoch: 0, update in batch 383000/???, loss: 7.109190940856934
epoch: 0, update in batch 384000/???, loss: 5.214654445648193
epoch: 0, update in batch 385000/???, loss: 7.055975437164307
epoch: 0, update in batch 386000/???, loss: 6.443846225738525
epoch: 0, update in batch 387000/???, loss: 5.544674873352051
epoch: 0, update in batch 388000/???, loss: 6.936171531677246
epoch: 0, update in batch 389000/???, loss: 6.646860599517822
epoch: 0, update in batch 390000/???, loss: 6.193584442138672
epoch: 0, update in batch 391000/???, loss: 5.9077558517456055
epoch: 0, update in batch 392000/???, loss: 5.029908657073975
epoch: 0, update in batch 393000/???, loss: 6.725222587585449
epoch: 0, update in batch 394000/???, loss: 6.814855098724365
epoch: 0, update in batch 395000/???, loss: 7.396543979644775
epoch: 0, update in batch 396000/???, loss: 6.993375301361084
epoch: 0, update in batch 397000/???, loss: 6.224326133728027
epoch: 0, update in batch 398000/???, loss: 6.301025390625
epoch: 0, update in batch 399000/???, loss: 6.707190036773682
epoch: 0, update in batch 400000/???, loss: 6.7646660804748535
epoch: 0, update in batch 401000/???, loss: 5.308177471160889
epoch: 0, update in batch 402000/???, loss: 8.35996150970459
epoch: 0, update in batch 403000/???, loss: 5.825610160827637
epoch: 0, update in batch 404000/???, loss: 6.310220718383789
epoch: 0, update in batch 405000/???, loss: 5.759210109710693
epoch: 0, update in batch 406000/???, loss: 6.32699728012085
epoch: 0, update in batch 407000/???, loss: 5.659378528594971
epoch: 0, update in batch 408000/???, loss: 6.216103553771973
epoch: 0, update in batch 409000/???, loss: 5.666914463043213
epoch: 0, update in batch 410000/???, loss: 6.419122219085693
epoch: 0, update in batch 411000/???, loss: 5.372750282287598
epoch: 0, update in batch 412000/???, loss: 6.839580535888672
epoch: 0, update in batch 413000/???, loss: 6.7682647705078125
epoch: 0, update in batch 414000/???, loss: 5.951648235321045
epoch: 0, update in batch 415000/???, loss: 6.181953430175781
epoch: 0, update in batch 416000/???, loss: 5.475704669952393
epoch: 0, update in batch 417000/???, loss: 6.383082866668701
epoch: 0, update in batch 418000/???, loss: 6.8107590675354
epoch: 0, update in batch 419000/???, loss: 5.753104209899902
epoch: 0, update in batch 420000/???, loss: 5.320840835571289
epoch: 0, update in batch 421000/???, loss: 7.377203464508057
epoch: 0, update in batch 422000/???, loss: 6.5706048011779785
epoch: 0, update in batch 423000/???, loss: 5.032872676849365
epoch: 0, update in batch 424000/???, loss: 5.781243324279785
epoch: 0, update in batch 425000/???, loss: 6.160118579864502
def predict_probs(tokens):
    """Build the prediction line for the word following ``tokens``.

    The file-level ``model`` scores the whole vocabulary for the position
    after the last token; the 30 most probable words are reported and the
    remaining probability mass is assigned to the wildcard entry.

    Args:
        tokens: Sequence of context words (already cleaned and tokenised).
            Words missing from the vocabulary are mapped to ``'<unk>'``.

    Returns:
        A string of space-separated ``word:probability`` pairs, ordered by
        decreasing probability, followed by ``:remaining`` where
        ``remaining`` is one minus the sum of the listed probabilities.
        ``'<unk>'`` is never listed; its mass is part of ``remaining``.
    """
    top_k = 30

    model.eval()

    unk_index = key_to_index['<unk>']
    indices = [key_to_index.get(w, unk_index) for w in tokens]
    x = torch.tensor([indices]).to(device)

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        y_pred, _ = model(x)

    last_word_logits = y_pred[0][-1]
    probs = torch.nn.functional.softmax(last_word_logits, dim=0).cpu().numpy()

    # Take one extra candidate so that top_k words remain even when '<unk>'
    # is among the best-scoring entries and has to be dropped.
    order = np.argsort(-probs, kind='stable')[:top_k + 1]
    best = [i for i in order if i != unk_index][:top_k]

    parts = [f'{index_to_key[i]}:{probs[i]}' for i in best]
    sum_prob = sum((float(probs[i]) for i in best), 0.0)
    parts.append(f':{1 - sum_prob}')

    return ' '.join(parts)
# Inputs of the two evaluation sets, read the same way as the training data.
# NOTE(review): on_bad_lines='skip' drops malformed rows, in which case out.tsv
# would have fewer lines than in.tsv -- confirm no rows are skipped.
dev_data = pd.read_csv(
    'dev-0/in.tsv.xz',
    sep='\t',
    header=None,
    quoting=csv.QUOTE_NONE,
    on_bad_lines='skip',
)
test_data = pd.read_csv(
    'test-A/in.tsv.xz',
    sep='\t',
    header=None,
    quoting=csv.QUOTE_NONE,
    on_bad_lines='skip',
)

# Write one prediction line per input row, first for dev-0 and then for test-A.
for eval_frame, out_path in ((dev_data, 'dev-0/out.tsv'), (test_data, 'test-A/out.tsv')):
    with open(out_path, 'w') as file:
        for index, row in eval_frame.iterrows():
            left_text = clean_text(str(row[6]))
            left_words = word_tokenize(left_text)
            # With at least six context words, predict from the last five;
            # otherwise put the whole probability mass on the wildcard.
            if len(left_words) >= 6:
                prediction = predict_probs(left_words[-5:])
            else:
                prediction = ':1.0'
            file.write(prediction + '\n')