challenging-america-word-ga.../run.ipynb

42 KiB

import pandas as pd
import numpy as np
import regex as re
import csv
import torch
from torch import nn
from gensim.models import Word2Vec
from nltk.tokenize import word_tokenize
# Free any cached GPU memory left over from a previous run, then select
# the compute device used by every .to(device) call below.
torch.cuda.empty_cache()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
def clean_text(text):
    """Lowercase *text*, mend hyphenated line breaks, expand contractions,
    and strip all Unicode punctuation.

    Contractions are expanded BEFORE punctuation removal: the original ran
    ``re.sub(r'\p{P}', ...)`` first, which deleted every apostrophe and made
    all of the ``.replace("'...", ...)`` calls unreachable dead code.
    """
    # rejoin words split by (escaped) line breaks, then flatten remaining breaks
    text = text.lower().replace('-\\\\\\\\\\\\\\\\n', '').replace('\\\\\\\\\\\\\\\\n', ' ')
    # crude suffix-based expansion (NOTE(review): "won't" -> "wo not" etc. —
    # accepted here as a rough normalization)
    text = text.replace("'t", " not").replace("'s", " is").replace("'ll", " will").replace("'m", " am").replace("'ve", " have")
    # drop every Unicode punctuation character (\p{P} needs the `regex` module)
    text = re.sub(r'\p{P}', '', text)

    return text
# Challenge TSV: left context is column 6, right context is column 7;
# expected.tsv holds the gap word, one per row.
train_data = pd.read_csv('train/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
train_labels = pd.read_csv('train/expected.tsv', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)

# keep only the left/right context columns
train_data = train_data[[6, 7]]
# append the label column; after the concat it is addressed as column 0
train_data = pd.concat([train_data, train_labels], axis=1)
class TrainCorpus:
    """Iterable of tokenized sentences for Word2Vec vocabulary building.

    Each row yields left context (col 6) + gap word (col 0, the label) +
    right context (col 7). The fields are joined with spaces — the original
    concatenated them directly, fusing the boundary words into bogus tokens.
    """

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        for _, row in self.data.iterrows():
            # space-join so the tokens at the field boundaries stay separate
            text = ' '.join([str(row[6]), str(row[0]), str(row[7])])
            text = clean_text(text)
            yield word_tokenize(text)
# Build the Word2Vec vocabulary over the first 100k rows. The vectors are
# never trained or used; only the token <-> index mappings are kept.
train_sentences = TrainCorpus(train_data.head(100000))
w2v_model = Word2Vec(vector_size=100, min_count=10)
w2v_model.build_vocab(corpus_iterable=train_sentences)

key_to_index = w2v_model.wv.key_to_index
index_to_key = w2v_model.wv.index_to_key

# Register an out-of-vocabulary sentinel as the LAST vocab index.
# NOTE(review): this mutates gensim's internal lists without adding a
# vector for '<unk>' — acceptable here since only the mappings are used.
index_to_key.append('<unk>')
key_to_index['<unk>'] = len(index_to_key) - 1

vocab_size = len(index_to_key)
print(vocab_size)
97122
class TrainDataset(torch.utils.data.IterableDataset):
    """Streams (input_context, target_context) index pairs for LM training.

    For every position i >= 5 in a tokenized row, yields the 5 token indices
    before i as input and the 5 token indices ending at i as target — i.e.
    the target sequence is the input sequence shifted right by one word.
    """

    def __init__(self, data, index_to_key, key_to_index):
        self.data = data
        self.index_to_key = index_to_key  # vocab index -> token
        self.key_to_index = key_to_index  # token -> vocab index
        self.vocab_size = len(key_to_index)

    def _encode(self, words):
        # map tokens to vocab indices, falling back to the '<unk>' sentinel
        unk = self.key_to_index['<unk>']
        return [self.key_to_index.get(word, unk) for word in words]

    def __iter__(self):
        for _, row in self.data.iterrows():
            # space-join left context, gap word, right context: the original
            # concatenated them directly, fusing the boundary words together
            text = ' '.join([str(row[6]), str(row[0]), str(row[7])])
            text = clean_text(text)
            tokens = word_tokenize(text)
            for i in range(5, len(tokens)):
                input_embed = self._encode(tokens[i - 5:i])
                target_embed = self._encode(tokens[i - 4:i + 1])
                yield np.asarray(input_embed, dtype=np.int64), np.asarray(target_embed, dtype=np.int64)
class Model(nn.Module):
    """2-layer GRU language model over token indices.

    forward() takes a LongTensor of token indices in the GRU's default
    seq-first layout and returns raw logits over the vocabulary for every
    position, plus the final hidden state.
    """

    def __init__(self, embed_size, vocab_size):
        super(Model, self).__init__()
        self.embed_size = embed_size
        self.vocab_size = vocab_size
        self.gru_size = 128
        self.num_layers = 2

        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=self.embed_size)
        self.gru = nn.GRU(input_size=self.embed_size, hidden_size=self.gru_size, num_layers=self.num_layers, dropout=0.2)
        self.fc = nn.Linear(self.gru_size, vocab_size)

    def forward(self, x, prev_state=None):
        embed = self.embed(x)
        output, state = self.gru(embed, prev_state)
        # Return raw logits: CrossEntropyLoss applies log-softmax itself.
        # (The original also computed a softmax here and threw it away —
        # pure wasted work on every forward pass.)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        # nn.GRU takes a single hidden tensor, not an (h, c) pair like LSTM;
        # the original returned a tuple, which the GRU would reject if used.
        return torch.zeros(self.num_layers, sequence_length, self.gru_size).to(device)
from torch.utils.data import DataLoader
from torch.optim import Adam

def train(dataset, model, max_epochs, batch_size):
    """Train *model* on *dataset* with Adam and cross-entropy loss.

    Logits are transposed to (batch, vocab, seq), the layout that
    nn.CrossEntropyLoss expects for per-position sequence targets.
    Progress is printed every 1000 batches; the total batch count of an
    IterableDataset is unknown up front, hence the '???'.
    """
    model.train()

    loader = DataLoader(dataset, batch_size=batch_size)
    loss_fn = nn.CrossEntropyLoss()
    opt = Adam(model.parameters(), lr=0.001)

    for epoch in range(max_epochs):
        for step, (inputs, targets) in enumerate(loader):
            opt.zero_grad()

            inputs, targets = inputs.to(device), targets.to(device)

            logits, _ = model(inputs)
            loss = loss_fn(logits.transpose(1, 2), targets)

            loss.backward()
            opt.step()

            if step % 1000 == 0:
                print(f'epoch: {epoch}, update in batch {step}/???, loss: {loss.item()}')
# Train for a single epoch over the first 100k rows, batch size 64,
# with 100-dim embeddings to match the Word2Vec vector_size above.
train_dataset = TrainDataset(train_data.head(100000), index_to_key, key_to_index)
model = Model(100, vocab_size).to(device)
train(train_dataset, model, 1, 64)
epoch: 0, update in batch 0/???, loss: 11.47425365447998
epoch: 0, update in batch 1000/???, loss: 7.042205810546875
epoch: 0, update in batch 2000/???, loss: 7.235440731048584
epoch: 0, update in batch 3000/???, loss: 7.251269340515137
epoch: 0, update in batch 4000/???, loss: 6.944191932678223
epoch: 0, update in batch 5000/???, loss: 6.263372421264648
epoch: 0, update in batch 6000/???, loss: 6.181947231292725
epoch: 0, update in batch 7000/???, loss: 6.508013725280762
epoch: 0, update in batch 8000/???, loss: 6.658236026763916
epoch: 0, update in batch 9000/???, loss: 6.536279201507568
epoch: 0, update in batch 10000/???, loss: 6.802626609802246
epoch: 0, update in batch 11000/???, loss: 6.961029052734375
epoch: 0, update in batch 12000/???, loss: 7.713824272155762
epoch: 0, update in batch 13000/???, loss: 8.100411415100098
epoch: 0, update in batch 14000/???, loss: 6.457145690917969
epoch: 0, update in batch 15000/???, loss: 6.850286960601807
epoch: 0, update in batch 16000/???, loss: 6.794063568115234
epoch: 0, update in batch 17000/???, loss: 6.311314582824707
epoch: 0, update in batch 18000/???, loss: 6.611917018890381
epoch: 0, update in batch 19000/???, loss: 5.679810523986816
epoch: 0, update in batch 20000/???, loss: 7.110655307769775
epoch: 0, update in batch 21000/???, loss: 7.170722961425781
epoch: 0, update in batch 22000/???, loss: 6.584908485412598
epoch: 0, update in batch 23000/???, loss: 7.224095344543457
epoch: 0, update in batch 24000/???, loss: 5.827445983886719
epoch: 0, update in batch 25000/???, loss: 6.444586753845215
epoch: 0, update in batch 26000/???, loss: 6.149054527282715
epoch: 0, update in batch 27000/???, loss: 6.259871482849121
epoch: 0, update in batch 28000/???, loss: 5.789839744567871
epoch: 0, update in batch 29000/???, loss: 7.025563716888428
epoch: 0, update in batch 30000/???, loss: 7.265492916107178
epoch: 0, update in batch 31000/???, loss: 4.921586036682129
epoch: 0, update in batch 32000/???, loss: 6.467754364013672
epoch: 0, update in batch 33000/???, loss: 7.393715858459473
epoch: 0, update in batch 34000/???, loss: 6.9696760177612305
epoch: 0, update in batch 35000/???, loss: 7.276318550109863
epoch: 0, update in batch 36000/???, loss: 7.011231899261475
epoch: 0, update in batch 37000/???, loss: 7.029260158538818
epoch: 0, update in batch 38000/???, loss: 6.723126411437988
epoch: 0, update in batch 39000/???, loss: 6.828773498535156
epoch: 0, update in batch 40000/???, loss: 6.069770336151123
epoch: 0, update in batch 41000/???, loss: 6.651298522949219
epoch: 0, update in batch 42000/???, loss: 7.455380916595459
epoch: 0, update in batch 43000/???, loss: 5.594773769378662
epoch: 0, update in batch 44000/???, loss: 6.102865219116211
epoch: 0, update in batch 45000/???, loss: 6.04202127456665
epoch: 0, update in batch 46000/???, loss: 6.472177982330322
epoch: 0, update in batch 47000/???, loss: 5.870923042297363
epoch: 0, update in batch 48000/???, loss: 6.286317348480225
epoch: 0, update in batch 49000/???, loss: 7.157052516937256
epoch: 0, update in batch 50000/???, loss: 5.888463020324707
epoch: 0, update in batch 51000/???, loss: 5.609915733337402
epoch: 0, update in batch 52000/???, loss: 6.565190315246582
epoch: 0, update in batch 53000/???, loss: 6.4924468994140625
epoch: 0, update in batch 54000/???, loss: 6.856420040130615
epoch: 0, update in batch 55000/???, loss: 7.389428615570068
epoch: 0, update in batch 56000/???, loss: 5.927685260772705
epoch: 0, update in batch 57000/???, loss: 7.4227423667907715
epoch: 0, update in batch 58000/???, loss: 6.46466064453125
epoch: 0, update in batch 59000/???, loss: 6.586294651031494
epoch: 0, update in batch 60000/???, loss: 5.797083854675293
epoch: 0, update in batch 61000/???, loss: 4.825878143310547
epoch: 0, update in batch 62000/???, loss: 6.911933898925781
epoch: 0, update in batch 63000/???, loss: 7.684759616851807
epoch: 0, update in batch 64000/???, loss: 5.716580390930176
epoch: 0, update in batch 65000/???, loss: 6.1738667488098145
epoch: 0, update in batch 66000/???, loss: 6.219714164733887
epoch: 0, update in batch 67000/???, loss: 5.4024128913879395
epoch: 0, update in batch 68000/???, loss: 6.912312984466553
epoch: 0, update in batch 69000/???, loss: 6.703289031982422
epoch: 0, update in batch 70000/???, loss: 7.375630855560303
epoch: 0, update in batch 71000/???, loss: 5.757082462310791
epoch: 0, update in batch 72000/???, loss: 5.992405414581299
epoch: 0, update in batch 73000/???, loss: 6.706838130950928
epoch: 0, update in batch 74000/???, loss: 7.376870155334473
epoch: 0, update in batch 75000/???, loss: 6.676860809326172
epoch: 0, update in batch 76000/???, loss: 5.904101848602295
epoch: 0, update in batch 77000/???, loss: 6.776932716369629
epoch: 0, update in batch 78000/???, loss: 5.682181358337402
epoch: 0, update in batch 79000/???, loss: 6.211178302764893
epoch: 0, update in batch 80000/???, loss: 6.366950035095215
epoch: 0, update in batch 81000/???, loss: 5.25206184387207
epoch: 0, update in batch 82000/???, loss: 6.30997371673584
epoch: 0, update in batch 83000/???, loss: 6.351908206939697
epoch: 0, update in batch 84000/???, loss: 7.659114837646484
epoch: 0, update in batch 85000/???, loss: 6.5041704177856445
epoch: 0, update in batch 86000/???, loss: 6.770291328430176
epoch: 0, update in batch 87000/???, loss: 6.530011177062988
epoch: 0, update in batch 88000/???, loss: 6.317249298095703
epoch: 0, update in batch 89000/???, loss: 6.191559314727783
epoch: 0, update in batch 90000/???, loss: 5.79150390625
epoch: 0, update in batch 91000/???, loss: 6.356796741485596
epoch: 0, update in batch 92000/???, loss: 7.3577141761779785
epoch: 0, update in batch 93000/???, loss: 6.529308319091797
epoch: 0, update in batch 94000/???, loss: 7.740485191345215
epoch: 0, update in batch 95000/???, loss: 6.348109245300293
epoch: 0, update in batch 96000/???, loss: 6.032902717590332
epoch: 0, update in batch 97000/???, loss: 4.505112648010254
epoch: 0, update in batch 98000/???, loss: 6.946290493011475
epoch: 0, update in batch 99000/???, loss: 6.237973213195801
epoch: 0, update in batch 100000/???, loss: 6.963421821594238
epoch: 0, update in batch 101000/???, loss: 5.309017181396484
epoch: 0, update in batch 102000/???, loss: 6.242384910583496
epoch: 0, update in batch 103000/???, loss: 6.8203558921813965
epoch: 0, update in batch 104000/???, loss: 6.242025852203369
epoch: 0, update in batch 105000/???, loss: 6.765100002288818
epoch: 0, update in batch 106000/???, loss: 6.8838043212890625
epoch: 0, update in batch 107000/???, loss: 6.856662750244141
epoch: 0, update in batch 108000/???, loss: 6.379549503326416
epoch: 0, update in batch 109000/???, loss: 6.797707557678223
epoch: 0, update in batch 110000/???, loss: 7.2699689865112305
epoch: 0, update in batch 111000/???, loss: 7.040057182312012
epoch: 0, update in batch 112000/???, loss: 6.7861223220825195
epoch: 0, update in batch 113000/???, loss: 6.064489364624023
epoch: 0, update in batch 114000/???, loss: 6.095967769622803
epoch: 0, update in batch 115000/???, loss: 5.757347106933594
epoch: 0, update in batch 116000/???, loss: 6.529908657073975
epoch: 0, update in batch 117000/???, loss: 6.030801296234131
epoch: 0, update in batch 118000/???, loss: 6.179767608642578
epoch: 0, update in batch 119000/???, loss: 5.436234474182129
epoch: 0, update in batch 120000/???, loss: 7.342876434326172
epoch: 0, update in batch 121000/???, loss: 6.862719535827637
epoch: 0, update in batch 122000/???, loss: 6.491606712341309
epoch: 0, update in batch 123000/???, loss: 7.195406436920166
epoch: 0, update in batch 124000/???, loss: 5.481313228607178
epoch: 0, update in batch 125000/???, loss: 7.963885307312012
epoch: 0, update in batch 126000/???, loss: 6.479039669036865
epoch: 0, update in batch 127000/???, loss: 7.037934303283691
epoch: 0, update in batch 128000/???, loss: 5.903053283691406
epoch: 0, update in batch 129000/???, loss: 6.815878391265869
epoch: 0, update in batch 130000/???, loss: 6.497969150543213
epoch: 0, update in batch 131000/???, loss: 5.623625755310059
epoch: 0, update in batch 132000/???, loss: 7.118441104888916
epoch: 0, update in batch 133000/???, loss: 5.964345455169678
epoch: 0, update in batch 134000/???, loss: 6.112139701843262
epoch: 0, update in batch 135000/???, loss: 6.5865373611450195
epoch: 0, update in batch 136000/???, loss: 7.498536109924316
epoch: 0, update in batch 137000/???, loss: 7.124758243560791
epoch: 0, update in batch 138000/???, loss: 6.871796607971191
epoch: 0, update in batch 139000/???, loss: 5.8565263748168945
epoch: 0, update in batch 140000/???, loss: 5.723143577575684
epoch: 0, update in batch 141000/???, loss: 5.601426124572754
epoch: 0, update in batch 142000/???, loss: 5.495566368103027
epoch: 0, update in batch 143000/???, loss: 6.936192989349365
epoch: 0, update in batch 144000/???, loss: 6.1843671798706055
epoch: 0, update in batch 145000/???, loss: 6.886034965515137
epoch: 0, update in batch 146000/???, loss: 6.655320167541504
epoch: 0, update in batch 147000/???, loss: 6.46828556060791
epoch: 0, update in batch 148000/???, loss: 5.607057571411133
epoch: 0, update in batch 149000/???, loss: 7.182212829589844
epoch: 0, update in batch 150000/???, loss: 7.241323947906494
epoch: 0, update in batch 151000/???, loss: 7.308540344238281
epoch: 0, update in batch 152000/???, loss: 5.267911434173584
epoch: 0, update in batch 153000/???, loss: 5.895949363708496
epoch: 0, update in batch 154000/???, loss: 6.629178524017334
epoch: 0, update in batch 155000/???, loss: 4.9156012535095215
epoch: 0, update in batch 156000/???, loss: 7.181819915771484
epoch: 0, update in batch 157000/???, loss: 7.438391208648682
epoch: 0, update in batch 158000/???, loss: 6.406006813049316
epoch: 0, update in batch 159000/???, loss: 6.486207008361816
epoch: 0, update in batch 160000/???, loss: 7.041951656341553
epoch: 0, update in batch 161000/???, loss: 5.310082912445068
epoch: 0, update in batch 162000/???, loss: 6.9074387550354
epoch: 0, update in batch 163000/???, loss: 6.644919395446777
epoch: 0, update in batch 164000/???, loss: 6.011733055114746
epoch: 0, update in batch 165000/???, loss: 6.494180202484131
epoch: 0, update in batch 166000/???, loss: 5.390150547027588
epoch: 0, update in batch 167000/???, loss: 6.627297401428223
epoch: 0, update in batch 168000/???, loss: 6.9020209312438965
epoch: 0, update in batch 169000/???, loss: 7.317750453948975
epoch: 0, update in batch 170000/???, loss: 5.69993782043457
epoch: 0, update in batch 171000/???, loss: 6.658817291259766
epoch: 0, update in batch 172000/???, loss: 6.422945976257324
epoch: 0, update in batch 173000/???, loss: 5.822269439697266
epoch: 0, update in batch 174000/???, loss: 6.513391017913818
epoch: 0, update in batch 175000/???, loss: 5.886590957641602
epoch: 0, update in batch 176000/???, loss: 7.119387149810791
epoch: 0, update in batch 177000/???, loss: 6.933981418609619
epoch: 0, update in batch 178000/???, loss: 6.678143501281738
epoch: 0, update in batch 179000/???, loss: 6.890423774719238
epoch: 0, update in batch 180000/???, loss: 6.932961940765381
epoch: 0, update in batch 181000/???, loss: 6.650975704193115
epoch: 0, update in batch 182000/???, loss: 6.732748985290527
epoch: 0, update in batch 183000/???, loss: 6.064764976501465
epoch: 0, update in batch 184000/???, loss: 5.282295227050781
epoch: 0, update in batch 185000/???, loss: 6.569302558898926
epoch: 0, update in batch 186000/???, loss: 5.800485610961914
epoch: 0, update in batch 187000/???, loss: 6.175991058349609
epoch: 0, update in batch 188000/???, loss: 5.405575752258301
epoch: 0, update in batch 189000/???, loss: 6.191354751586914
epoch: 0, update in batch 190000/???, loss: 6.156663417816162
epoch: 0, update in batch 191000/???, loss: 6.937534332275391
epoch: 0, update in batch 192000/???, loss: 6.562686920166016
epoch: 0, update in batch 193000/???, loss: 6.639985084533691
epoch: 0, update in batch 194000/???, loss: 7.285438537597656
epoch: 0, update in batch 195000/???, loss: 6.528258323669434
epoch: 0, update in batch 196000/???, loss: 8.326434135437012
epoch: 0, update in batch 197000/???, loss: 6.781360626220703
epoch: 0, update in batch 198000/???, loss: 7.223299980163574
epoch: 0, update in batch 199000/???, loss: 6.411007881164551
epoch: 0, update in batch 200000/???, loss: 5.885635852813721
epoch: 0, update in batch 201000/???, loss: 5.706809043884277
epoch: 0, update in batch 202000/???, loss: 6.230217933654785
epoch: 0, update in batch 203000/???, loss: 7.056562900543213
epoch: 0, update in batch 204000/???, loss: 7.2273077964782715
epoch: 0, update in batch 205000/???, loss: 6.342462539672852
epoch: 0, update in batch 206000/???, loss: 6.556817054748535
epoch: 0, update in batch 207000/???, loss: 5.882349967956543
epoch: 0, update in batch 208000/???, loss: 6.755805015563965
epoch: 0, update in batch 209000/???, loss: 6.5045623779296875
epoch: 0, update in batch 210000/???, loss: 6.525590419769287
epoch: 0, update in batch 211000/???, loss: 6.49679708480835
epoch: 0, update in batch 212000/???, loss: 6.562323093414307
epoch: 0, update in batch 213000/???, loss: 5.227139472961426
epoch: 0, update in batch 214000/???, loss: 7.044825077056885
epoch: 0, update in batch 215000/???, loss: 6.002442359924316
epoch: 0, update in batch 216000/???, loss: 6.084803581237793
epoch: 0, update in batch 217000/???, loss: 7.425839900970459
epoch: 0, update in batch 218000/???, loss: 6.818853855133057
epoch: 0, update in batch 219000/???, loss: 7.0153374671936035
epoch: 0, update in batch 220000/???, loss: 6.219962120056152
epoch: 0, update in batch 221000/???, loss: 5.9975385665893555
epoch: 0, update in batch 222000/???, loss: 6.480047702789307
epoch: 0, update in batch 223000/???, loss: 6.405727386474609
epoch: 0, update in batch 224000/???, loss: 4.7763471603393555
epoch: 0, update in batch 225000/???, loss: 6.615710258483887
epoch: 0, update in batch 226000/???, loss: 6.385044574737549
epoch: 0, update in batch 227000/???, loss: 7.260453701019287
epoch: 0, update in batch 228000/???, loss: 6.9794135093688965
epoch: 0, update in batch 229000/???, loss: 6.235735893249512
epoch: 0, update in batch 230000/???, loss: 6.478426456451416
epoch: 0, update in batch 231000/???, loss: 6.181302547454834
epoch: 0, update in batch 232000/???, loss: 5.826043128967285
epoch: 0, update in batch 233000/???, loss: 5.9517011642456055
epoch: 0, update in batch 234000/???, loss: 8.0064697265625
epoch: 0, update in batch 235000/???, loss: 6.7822675704956055
epoch: 0, update in batch 236000/???, loss: 6.293349742889404
epoch: 0, update in batch 237000/???, loss: 6.442999362945557
epoch: 0, update in batch 238000/???, loss: 6.282561302185059
epoch: 0, update in batch 239000/???, loss: 7.166723728179932
epoch: 0, update in batch 240000/???, loss: 7.189905643463135
epoch: 0, update in batch 241000/???, loss: 8.462562561035156
epoch: 0, update in batch 242000/???, loss: 7.446291923522949
epoch: 0, update in batch 243000/???, loss: 6.382981777191162
epoch: 0, update in batch 244000/???, loss: 7.635994911193848
epoch: 0, update in batch 245000/???, loss: 6.635537147521973
epoch: 0, update in batch 246000/???, loss: 6.068560600280762
epoch: 0, update in batch 247000/???, loss: 6.193384170532227
epoch: 0, update in batch 248000/???, loss: 5.702363967895508
epoch: 0, update in batch 249000/???, loss: 6.09995174407959
epoch: 0, update in batch 250000/???, loss: 6.312221050262451
epoch: 0, update in batch 251000/???, loss: 5.853858470916748
epoch: 0, update in batch 252000/???, loss: 5.886989593505859
epoch: 0, update in batch 253000/???, loss: 5.801788330078125
epoch: 0, update in batch 254000/???, loss: 6.032407760620117
epoch: 0, update in batch 255000/???, loss: 7.480917453765869
epoch: 0, update in batch 256000/???, loss: 6.578718662261963
epoch: 0, update in batch 257000/???, loss: 6.344462871551514
epoch: 0, update in batch 258000/???, loss: 5.939858436584473
epoch: 0, update in batch 259000/???, loss: 5.181772232055664
epoch: 0, update in batch 260000/???, loss: 6.640598297119141
epoch: 0, update in batch 261000/???, loss: 7.189258575439453
epoch: 0, update in batch 262000/???, loss: 6.2269287109375
epoch: 0, update in batch 263000/???, loss: 5.8858795166015625
epoch: 0, update in batch 264000/???, loss: 6.333988666534424
epoch: 0, update in batch 265000/???, loss: 6.313681602478027
epoch: 0, update in batch 266000/???, loss: 5.485809803009033
epoch: 0, update in batch 267000/???, loss: 6.250800609588623
epoch: 0, update in batch 268000/???, loss: 6.676806449890137
epoch: 0, update in batch 269000/???, loss: 5.6487932205200195
epoch: 0, update in batch 270000/???, loss: 6.648938179016113
epoch: 0, update in batch 271000/???, loss: 6.26931095123291
epoch: 0, update in batch 272000/???, loss: 5.343636512756348
epoch: 0, update in batch 273000/???, loss: 7.051453590393066
epoch: 0, update in batch 274000/???, loss: 4.578436851501465
epoch: 0, update in batch 275000/???, loss: 5.400996685028076
epoch: 0, update in batch 276000/???, loss: 6.129047870635986
epoch: 0, update in batch 277000/???, loss: 7.549851894378662
epoch: 0, update in batch 278000/???, loss: 6.093559265136719
epoch: 0, update in batch 279000/???, loss: 5.6921467781066895
epoch: 0, update in batch 280000/???, loss: 5.789463996887207
epoch: 0, update in batch 281000/???, loss: 5.681942939758301
epoch: 0, update in batch 282000/???, loss: 6.750497341156006
epoch: 0, update in batch 283000/???, loss: 5.960292339324951
epoch: 0, update in batch 284000/???, loss: 6.160388469696045
epoch: 0, update in batch 285000/???, loss: 7.137685298919678
epoch: 0, update in batch 286000/???, loss: 7.7431464195251465
epoch: 0, update in batch 287000/???, loss: 5.229738712310791
epoch: 0, update in batch 288000/???, loss: 6.654232025146484
epoch: 0, update in batch 289000/???, loss: 6.229329586029053
epoch: 0, update in batch 290000/???, loss: 7.188180446624756
epoch: 0, update in batch 291000/???, loss: 6.244111061096191
epoch: 0, update in batch 292000/???, loss: 7.199154853820801
epoch: 0, update in batch 293000/???, loss: 7.1866865158081055
epoch: 0, update in batch 294000/???, loss: 6.574115753173828
epoch: 0, update in batch 295000/???, loss: 6.487138271331787
epoch: 0, update in batch 296000/???, loss: 5.813161849975586
epoch: 0, update in batch 297000/???, loss: 6.159414291381836
epoch: 0, update in batch 298000/???, loss: 7.256616115570068
epoch: 0, update in batch 299000/???, loss: 7.511231899261475
epoch: 0, update in batch 300000/???, loss: 6.148821830749512
epoch: 0, update in batch 301000/???, loss: 7.108969211578369
epoch: 0, update in batch 302000/???, loss: 6.528176307678223
epoch: 0, update in batch 303000/???, loss: 6.276839256286621
epoch: 0, update in batch 304000/???, loss: 6.484020233154297
epoch: 0, update in batch 305000/???, loss: 6.38557767868042
epoch: 0, update in batch 306000/???, loss: 7.068814754486084
epoch: 0, update in batch 307000/???, loss: 5.844017505645752
epoch: 0, update in batch 308000/???, loss: 4.25785493850708
epoch: 0, update in batch 309000/???, loss: 6.709985256195068
epoch: 0, update in batch 310000/???, loss: 6.543104648590088
epoch: 0, update in batch 311000/???, loss: 6.675828456878662
epoch: 0, update in batch 312000/???, loss: 5.82969856262207
epoch: 0, update in batch 313000/???, loss: 6.05246639251709
epoch: 0, update in batch 314000/???, loss: 7.2366156578063965
epoch: 0, update in batch 315000/???, loss: 5.039820194244385
epoch: 0, update in batch 316000/???, loss: 5.943173885345459
epoch: 0, update in batch 317000/???, loss: 6.2509002685546875
epoch: 0, update in batch 318000/???, loss: 6.451228141784668
epoch: 0, update in batch 319000/???, loss: 6.6381049156188965
epoch: 0, update in batch 320000/???, loss: 6.570329189300537
epoch: 0, update in batch 321000/???, loss: 5.376622200012207
epoch: 0, update in batch 322000/???, loss: 6.487462520599365
epoch: 0, update in batch 323000/???, loss: 6.676497459411621
epoch: 0, update in batch 324000/???, loss: 6.283420562744141
epoch: 0, update in batch 325000/???, loss: 6.164648532867432
epoch: 0, update in batch 326000/???, loss: 6.839153289794922
epoch: 0, update in batch 327000/???, loss: 6.435141086578369
epoch: 0, update in batch 328000/???, loss: 6.160590171813965
epoch: 0, update in batch 329000/???, loss: 5.876160621643066
epoch: 0, update in batch 330000/???, loss: 6.47445011138916
epoch: 0, update in batch 331000/???, loss: 6.294231414794922
epoch: 0, update in batch 332000/???, loss: 6.099027156829834
epoch: 0, update in batch 333000/???, loss: 6.986542701721191
epoch: 0, update in batch 334000/???, loss: 7.018263816833496
epoch: 0, update in batch 335000/???, loss: 6.906959533691406
epoch: 0, update in batch 336000/???, loss: 6.12356424331665
epoch: 0, update in batch 337000/???, loss: 6.316069602966309
epoch: 0, update in batch 338000/???, loss: 6.908566474914551
epoch: 0, update in batch 339000/???, loss: 5.628839492797852
epoch: 0, update in batch 340000/???, loss: 7.069979667663574
epoch: 0, update in batch 341000/???, loss: 5.350735187530518
epoch: 0, update in batch 342000/???, loss: 5.377245903015137
epoch: 0, update in batch 343000/???, loss: 5.2340989112854
epoch: 0, update in batch 344000/???, loss: 6.087491512298584
epoch: 0, update in batch 345000/???, loss: 6.162985801696777
epoch: 0, update in batch 346000/???, loss: 5.655491828918457
epoch: 0, update in batch 347000/???, loss: 5.311842918395996
epoch: 0, update in batch 348000/???, loss: 7.577170372009277
epoch: 0, update in batch 349000/???, loss: 6.730460166931152
epoch: 0, update in batch 350000/???, loss: 6.782231330871582
epoch: 0, update in batch 351000/???, loss: 6.789486885070801
epoch: 0, update in batch 352000/???, loss: 5.473587989807129
epoch: 0, update in batch 353000/???, loss: 5.531443119049072
epoch: 0, update in batch 354000/???, loss: 7.220989227294922
epoch: 0, update in batch 355000/???, loss: 5.954288005828857
epoch: 0, update in batch 356000/???, loss: 4.112783432006836
epoch: 0, update in batch 357000/???, loss: 5.409672737121582
epoch: 0, update in batch 358000/???, loss: 6.408724784851074
epoch: 0, update in batch 359000/???, loss: 6.744941711425781
epoch: 0, update in batch 360000/???, loss: 6.218225479125977
epoch: 0, update in batch 361000/???, loss: 6.071394920349121
epoch: 0, update in batch 362000/???, loss: 6.137121677398682
epoch: 0, update in batch 363000/???, loss: 5.876864433288574
epoch: 0, update in batch 364000/???, loss: 7.715426445007324
epoch: 0, update in batch 365000/???, loss: 6.217362880706787
epoch: 0, update in batch 366000/???, loss: 6.741396903991699
epoch: 0, update in batch 367000/???, loss: 6.4564313888549805
epoch: 0, update in batch 368000/???, loss: 6.994439601898193
epoch: 0, update in batch 369000/???, loss: 6.061278820037842
epoch: 0, update in batch 370000/???, loss: 4.894576549530029
epoch: 0, update in batch 371000/???, loss: 6.351264953613281
epoch: 0, update in batch 372000/???, loss: 6.826904296875
epoch: 0, update in batch 373000/???, loss: 6.090312480926514
epoch: 0, update in batch 374000/???, loss: 5.797528266906738
epoch: 0, update in batch 375000/???, loss: 7.3235673904418945
epoch: 0, update in batch 376000/???, loss: 5.5752973556518555
epoch: 0, update in batch 377000/???, loss: 6.29438591003418
epoch: 0, update in batch 378000/???, loss: 5.238917827606201
epoch: 0, update in batch 379000/???, loss: 5.542972564697266
epoch: 0, update in batch 380000/???, loss: 6.5024614334106445
epoch: 0, update in batch 381000/???, loss: 6.918997287750244
epoch: 0, update in batch 382000/???, loss: 5.694029331207275
epoch: 0, update in batch 383000/???, loss: 7.109190940856934
epoch: 0, update in batch 384000/???, loss: 5.214654445648193
epoch: 0, update in batch 385000/???, loss: 7.055975437164307
epoch: 0, update in batch 386000/???, loss: 6.443846225738525
epoch: 0, update in batch 387000/???, loss: 5.544674873352051
epoch: 0, update in batch 388000/???, loss: 6.936171531677246
epoch: 0, update in batch 389000/???, loss: 6.646860599517822
epoch: 0, update in batch 390000/???, loss: 6.193584442138672
epoch: 0, update in batch 391000/???, loss: 5.9077558517456055
epoch: 0, update in batch 392000/???, loss: 5.029908657073975
epoch: 0, update in batch 393000/???, loss: 6.725222587585449
epoch: 0, update in batch 394000/???, loss: 6.814855098724365
epoch: 0, update in batch 395000/???, loss: 7.396543979644775
epoch: 0, update in batch 396000/???, loss: 6.993375301361084
epoch: 0, update in batch 397000/???, loss: 6.224326133728027
epoch: 0, update in batch 398000/???, loss: 6.301025390625
epoch: 0, update in batch 399000/???, loss: 6.707190036773682
epoch: 0, update in batch 400000/???, loss: 6.7646660804748535
epoch: 0, update in batch 401000/???, loss: 5.308177471160889
epoch: 0, update in batch 402000/???, loss: 8.35996150970459
epoch: 0, update in batch 403000/???, loss: 5.825610160827637
epoch: 0, update in batch 404000/???, loss: 6.310220718383789
epoch: 0, update in batch 405000/???, loss: 5.759210109710693
epoch: 0, update in batch 406000/???, loss: 6.32699728012085
epoch: 0, update in batch 407000/???, loss: 5.659378528594971
epoch: 0, update in batch 408000/???, loss: 6.216103553771973
epoch: 0, update in batch 409000/???, loss: 5.666914463043213
epoch: 0, update in batch 410000/???, loss: 6.419122219085693
epoch: 0, update in batch 411000/???, loss: 5.372750282287598
epoch: 0, update in batch 412000/???, loss: 6.839580535888672
epoch: 0, update in batch 413000/???, loss: 6.7682647705078125
epoch: 0, update in batch 414000/???, loss: 5.951648235321045
epoch: 0, update in batch 415000/???, loss: 6.181953430175781
epoch: 0, update in batch 416000/???, loss: 5.475704669952393
epoch: 0, update in batch 417000/???, loss: 6.383082866668701
epoch: 0, update in batch 418000/???, loss: 6.8107590675354
epoch: 0, update in batch 419000/???, loss: 5.753104209899902
epoch: 0, update in batch 420000/???, loss: 5.320840835571289
epoch: 0, update in batch 421000/???, loss: 7.377203464508057
epoch: 0, update in batch 422000/???, loss: 6.5706048011779785
epoch: 0, update in batch 423000/???, loss: 5.032872676849365
epoch: 0, update in batch 424000/???, loss: 5.781243324279785
epoch: 0, update in batch 425000/???, loss: 6.160118579864502
def predict_probs(tokens):
    """Return the challenge-format prediction line for a left-context window.

    Runs the model on *tokens*, softmaxes the logits of the last position,
    and emits the 30 most probable words — excluding the '<unk>' sentinel,
    which sits at the last vocab index — as 'word:prob ... :rest' where
    rest is the leftover probability mass.

    Fixes over the original: removed an unused model.init_state() call and
    an unused (and nondeterministic) np.random.choice() draw, wrapped
    inference in no_grad, and replaced the manual O(V*30)
    worst-element-replacement scan with a single argsort.
    """
    model.eval()

    unk_index = train_dataset.key_to_index['<unk>']
    with torch.no_grad():  # inference only; no autograd graph needed
        x = torch.tensor([[train_dataset.key_to_index.get(w, unk_index) for w in tokens]]).to(device)
        y_pred, _ = model(x)
        last_word_logits = y_pred[0][-1]
        probs = torch.nn.functional.softmax(last_word_logits, dim=0).cpu().numpy()

    # top 30 vocab indices by probability, skipping the '<unk>' sentinel
    order = np.argsort(probs)[::-1]
    top_indices = [i for i in order if i != len(probs) - 1][:30]

    prediction = ''
    sum_prob = 0.0
    for index in top_indices:
        sum_prob += probs[index]
        word_text = index_to_key[index]
        prediction += f'{word_text}:{probs[index]} '
    prediction += f':{1 - sum_prob}'

    return prediction
# Inference inputs for the dev and test splits (same TSV layout as train).
dev_data = pd.read_csv('dev-0/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
test_data = pd.read_csv('test-A/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
# One prediction line per dev row, from the last 5 tokens of the left
# context (column 6); rows with too little context get the catch-all ':1.0'.
with open('dev-0/out.tsv', 'w') as file:
    for index, row in dev_data.iterrows():
        left_text = clean_text(str(row[6]))
        left_words = word_tokenize(left_text)
        if len(left_words) < 6:
            prediction = ':1.0'
        else:
            prediction = predict_probs(left_words[-5:])
        file.write(prediction + '\n')
# Same procedure for the test split.
with open('test-A/out.tsv', 'w') as file:
    for index, row in test_data.iterrows():
        left_text = clean_text(str(row[6]))
        left_words = word_tokenize(left_text)
        if len(left_words) < 6:
            prediction = ':1.0'
        else:
            prediction = predict_probs(left_words[-5:])
        file.write(prediction + '\n')