challenging-america-word-ga.../run.ipynb

import pandas as pd
import numpy as np
import regex as re
import csv
import torch
from torch import nn
from gensim.models import Word2Vec
from nltk.tokenize import word_tokenize  # requires the NLTK 'punkt' tokenizer data
def clean_text(text):
    # The corpus stores line breaks as the literal two characters "\n";
    # drop hyphenated breaks and turn the remaining ones into spaces.
    text = text.lower().replace('-\\n', '').replace('\\n', ' ')
    # Expand contractions before stripping punctuation; once the apostrophes
    # are removed, these patterns can no longer match.
    text = text.replace("'t", " not").replace("'s", " is").replace("'ll", " will").replace("'m", " am").replace("'ve", " have")
    # Strip all Unicode punctuation.
    text = re.sub(r'\p{P}', '', text)

    return text
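# Quick sanity check on a made-up snippet (illustrative only): the escaped
# line break is dehyphenated and the contraction expanded.
print(clean_text("We'll meet to-\\nmorrow"))  # -> we will meet tomorrow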
# in.tsv columns 6 and 7 hold the text to the left and right of the gap;
# expected.tsv holds the missing word itself.
train_data = pd.read_csv('train/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
train_labels = pd.read_csv('train/expected.tsv', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)

train_data = train_data[[6, 7]]
train_data = pd.concat([train_data, train_labels], axis=1)
class TrainCorpus:
    def __init__(self, data):
        self.data = data

    def __iter__(self):
        for _, row in self.data.iterrows():
            # Re-insert the gap word (column 0) between its contexts, with
            # spaces so adjacent words do not fuse together.
            text = str(row[6]) + ' ' + str(row[0]) + ' ' + str(row[7])
            text = clean_text(text)
            yield word_tokenize(text)
train_sentences = TrainCorpus(train_data.head(80000))
# Word2Vec serves only as a convenient frequency-filtered vocabulary builder
# (min_count=10); its embeddings are never trained or used downstream.
w2v_model = Word2Vec(vector_size=100, min_count=10)
w2v_model.build_vocab(corpus_iterable=train_sentences)

key_to_index = w2v_model.wv.key_to_index
index_to_key = w2v_model.wv.index_to_key

# Register an out-of-vocabulary token as the final vocabulary entry.
index_to_key.append('<unk>')
key_to_index['<unk>'] = len(index_to_key) - 1

vocab_size = len(index_to_key)
print(vocab_size)
81477
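# gensim keeps the vocabulary sorted by frequency; a quick, purely
# illustrative look at the most common tokens (values depend on the corpus):
print(index_to_key[:10])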
class TrainDataset(torch.utils.data.IterableDataset):
    def __init__(self, data, index_to_key, key_to_index, reversed=False):
        self.reversed = reversed
        self.data = data
        self.index_to_key = index_to_key
        self.key_to_index = key_to_index
        self.vocab_size = len(key_to_index)

    def __iter__(self):
        for _, row in self.data.iterrows():
            text = str(row[6]) + ' ' + str(row[0]) + ' ' + str(row[7])
            text = clean_text(text)
            tokens = word_tokenize(text)
            if self.reversed:
                tokens = list(reversed(tokens))
            # Slide a 5-token window over the document; the target window is
            # the input shifted one token to the right.
            for i in range(5, len(tokens)):
                input_context = tokens[i-5:i]
                target_context = tokens[i-4:i+1]

                input_embed = [self.key_to_index.get(word, self.key_to_index['<unk>']) for word in input_context]
                target_embed = [self.key_to_index.get(word, self.key_to_index['<unk>']) for word in target_context]

                yield np.asarray(input_embed, dtype=np.int64), np.asarray(target_embed, dtype=np.int64)
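# Illustrative peek at one (input, target) pair; the exact ids depend on the
# data, but the target is always the input shifted by one token.
x0, y0 = next(iter(TrainDataset(train_data.head(1), index_to_key, key_to_index)))
print(x0, y0)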
class Model(nn.Module):
    def __init__(self, embed_size, vocab_size):
        super(Model, self).__init__()
        self.embed_size = embed_size
        self.vocab_size = vocab_size
        self.lstm_size = 128
        self.num_layers = 2

        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=self.embed_size)
        # batch_first=True because the DataLoader yields (batch, seq) tensors;
        # without it the LSTM would treat the batch dimension as time.
        self.lstm = nn.LSTM(input_size=self.embed_size, hidden_size=self.lstm_size, num_layers=self.num_layers, dropout=0.2, batch_first=True)
        self.fc = nn.Linear(self.lstm_size, vocab_size)

    def forward(self, x, prev_state=None):
        embed = self.embed(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, batch_size):
        # Hidden and cell states are (num_layers, batch, lstm_size); the
        # original referenced a non-existent gru_size attribute.
        zeros = torch.zeros(self.num_layers, batch_size, self.lstm_size).to(device)
        return (zeros, zeros)
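# Shape sanity check on a throwaway CPU instance (illustrative): a batch of
# two 5-token windows should yield per-position vocabulary logits.
_logits, (_h, _c) = Model(100, vocab_size)(torch.randint(0, vocab_size, (2, 5)))
print(_logits.shape)  # torch.Size([2, 5, 81477])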
from torch.utils.data import DataLoader
from torch.optim import Adam

def train(dataset, model, max_epochs, batch_size):
    model.train()

    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.001)

    for epoch in range(max_epochs):
        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            
            x = x.to(device)
            y = y.to(device)
            
            y_pred, (state_h, state_c) = model(x)
            # CrossEntropyLoss expects class scores on dim 1, i.e. the logits
            # as (batch, vocab, seq) against integer targets of (batch, seq).
            loss = criterion(y_pred.transpose(1, 2), y)

            loss.backward()
            optimizer.step()

            if batch % 100 == 0:
                # An IterableDataset has no len(), so the total batch count is unknown.
                print(f'epoch: {epoch}, update in batch {batch}/???, loss: {loss.item()}')
torch.cuda.empty_cache()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_dataset_front = TrainDataset(train_data.head(8000), index_to_key, key_to_index, False)
model_front = Model(100, vocab_size).to(device)
train(train_dataset_front, model_front, 1, 64)
epoch: 0, update in batch 0/???, loss: 11.315739631652832
epoch: 0, update in batch 100/???, loss: 8.016324996948242
epoch: 0, update in batch 200/???, loss: 7.45602560043335
epoch: 0, update in batch 300/???, loss: 6.306332588195801
epoch: 0, update in batch 400/???, loss: 8.629552841186523
epoch: 0, update in batch 500/???, loss: 7.637443542480469
epoch: 0, update in batch 600/???, loss: 7.67318868637085
epoch: 0, update in batch 700/???, loss: 7.2209930419921875
epoch: 0, update in batch 800/???, loss: 7.739532470703125
epoch: 0, update in batch 900/???, loss: 7.219891548156738
epoch: 0, update in batch 1000/???, loss: 6.8804473876953125
epoch: 0, update in batch 1100/???, loss: 7.228173732757568
epoch: 0, update in batch 1200/???, loss: 6.513087272644043
epoch: 0, update in batch 1300/???, loss: 7.142991542816162
epoch: 0, update in batch 1400/???, loss: 7.711663246154785
epoch: 0, update in batch 1500/???, loss: 6.894327640533447
epoch: 0, update in batch 1600/???, loss: 7.723884582519531
epoch: 0, update in batch 1700/???, loss: 8.409640312194824
epoch: 0, update in batch 1800/???, loss: 6.570927619934082
epoch: 0, update in batch 1900/???, loss: 6.906421661376953
epoch: 0, update in batch 2000/???, loss: 7.197023868560791
epoch: 0, update in batch 2100/???, loss: 6.892503261566162
epoch: 0, update in batch 2200/???, loss: 7.109471321105957
epoch: 0, update in batch 2300/???, loss: 8.84702205657959
epoch: 0, update in batch 2400/???, loss: 7.394454002380371
epoch: 0, update in batch 2500/???, loss: 7.380859375
epoch: 0, update in batch 2600/???, loss: 6.635237693786621
epoch: 0, update in batch 2700/???, loss: 6.869620323181152
epoch: 0, update in batch 2800/???, loss: 6.656294822692871
epoch: 0, update in batch 2900/???, loss: 8.090291976928711
epoch: 0, update in batch 3000/???, loss: 7.012345314025879
epoch: 0, update in batch 3100/???, loss: 6.7099809646606445
epoch: 0, update in batch 3200/???, loss: 6.798626899719238
epoch: 0, update in batch 3300/???, loss: 6.510752201080322
epoch: 0, update in batch 3400/???, loss: 7.742552757263184
epoch: 0, update in batch 3500/???, loss: 7.3319292068481445
epoch: 0, update in batch 3600/???, loss: 8.022462844848633
epoch: 0, update in batch 3700/???, loss: 5.883602619171143
epoch: 0, update in batch 3800/???, loss: 6.235389232635498
epoch: 0, update in batch 3900/???, loss: 7.012289524078369
epoch: 0, update in batch 4000/???, loss: 7.005420684814453
epoch: 0, update in batch 4100/???, loss: 6.595402717590332
epoch: 0, update in batch 4200/???, loss: 6.7428154945373535
epoch: 0, update in batch 4300/???, loss: 6.358878135681152
epoch: 0, update in batch 4400/???, loss: 6.6188201904296875
epoch: 0, update in batch 4500/???, loss: 7.08281946182251
epoch: 0, update in batch 4600/???, loss: 5.705609321594238
epoch: 0, update in batch 4700/???, loss: 7.1878180503845215
epoch: 0, update in batch 4800/???, loss: 7.071160793304443
epoch: 0, update in batch 4900/???, loss: 6.768280029296875
epoch: 0, update in batch 5000/???, loss: 6.507267951965332
epoch: 0, update in batch 5100/???, loss: 6.6431379318237305
epoch: 0, update in batch 5200/???, loss: 6.719052314758301
epoch: 0, update in batch 5300/???, loss: 7.172060489654541
epoch: 0, update in batch 5400/???, loss: 5.98638916015625
epoch: 0, update in batch 5500/???, loss: 5.674165725708008
epoch: 0, update in batch 5600/???, loss: 5.612569808959961
epoch: 0, update in batch 5700/???, loss: 6.307109832763672
epoch: 0, update in batch 5800/???, loss: 5.382391452789307
epoch: 0, update in batch 5900/???, loss: 5.712988376617432
epoch: 0, update in batch 6000/???, loss: 6.371735572814941
epoch: 0, update in batch 6100/???, loss: 6.417542457580566
epoch: 0, update in batch 6200/???, loss: 7.14879846572876
epoch: 0, update in batch 6300/???, loss: 7.0701189041137695
epoch: 0, update in batch 6400/???, loss: 7.048495292663574
epoch: 0, update in batch 6500/???, loss: 7.3384833335876465
epoch: 0, update in batch 6600/???, loss: 6.561330318450928
epoch: 0, update in batch 6700/???, loss: 6.839573860168457
epoch: 0, update in batch 6800/???, loss: 6.5179548263549805
epoch: 0, update in batch 6900/???, loss: 7.246607303619385
epoch: 0, update in batch 7000/???, loss: 6.5699052810668945
epoch: 0, update in batch 7100/???, loss: 7.202715873718262
epoch: 0, update in batch 7200/???, loss: 6.1833648681640625
epoch: 0, update in batch 7300/???, loss: 5.977782249450684
epoch: 0, update in batch 7400/???, loss: 6.717446327209473
epoch: 0, update in batch 7500/???, loss: 6.574376583099365
epoch: 0, update in batch 7600/???, loss: 5.8418450355529785
epoch: 0, update in batch 7700/???, loss: 6.282655715942383
epoch: 0, update in batch 7800/???, loss: 6.065321922302246
epoch: 0, update in batch 7900/???, loss: 6.415077209472656
epoch: 0, update in batch 8000/???, loss: 6.482673645019531
epoch: 0, update in batch 8100/???, loss: 6.670407772064209
epoch: 0, update in batch 8200/???, loss: 6.799211025238037
epoch: 0, update in batch 8300/???, loss: 7.299313545227051
epoch: 0, update in batch 8400/???, loss: 7.42974328994751
epoch: 0, update in batch 8500/???, loss: 8.549559593200684
epoch: 0, update in batch 8600/???, loss: 6.794680118560791
epoch: 0, update in batch 8700/???, loss: 7.390380859375
epoch: 0, update in batch 8800/???, loss: 7.552660942077637
epoch: 0, update in batch 8900/???, loss: 6.663547515869141
epoch: 0, update in batch 9000/???, loss: 6.5236711502075195
epoch: 0, update in batch 9100/???, loss: 7.666424751281738
epoch: 0, update in batch 9200/???, loss: 6.479496955871582
epoch: 0, update in batch 9300/???, loss: 5.5056304931640625
epoch: 0, update in batch 9400/???, loss: 6.6904096603393555
epoch: 0, update in batch 9500/???, loss: 6.9318037033081055
epoch: 0, update in batch 9600/???, loss: 6.521365165710449
epoch: 0, update in batch 9700/???, loss: 6.376631736755371
epoch: 0, update in batch 9800/???, loss: 6.4104766845703125
epoch: 0, update in batch 9900/???, loss: 7.3995232582092285
epoch: 0, update in batch 10000/???, loss: 6.510337829589844
epoch: 0, update in batch 10100/???, loss: 6.2512407302856445
epoch: 0, update in batch 10200/???, loss: 6.048404216766357
epoch: 0, update in batch 10300/???, loss: 6.832150936126709
epoch: 0, update in batch 10400/???, loss: 6.7485456466674805
epoch: 0, update in batch 10500/???, loss: 5.385656833648682
epoch: 0, update in batch 10600/???, loss: 6.769070625305176
epoch: 0, update in batch 10700/???, loss: 6.857029914855957
epoch: 0, update in batch 10800/???, loss: 5.991332530975342
epoch: 0, update in batch 10900/???, loss: 6.5500006675720215
epoch: 0, update in batch 11000/???, loss: 6.951509952545166
epoch: 0, update in batch 11100/???, loss: 6.396986961364746
epoch: 0, update in batch 11200/???, loss: 6.639346122741699
epoch: 0, update in batch 11300/???, loss: 5.87351655960083
epoch: 0, update in batch 11400/???, loss: 5.996974945068359
epoch: 0, update in batch 11500/???, loss: 7.103158473968506
epoch: 0, update in batch 11600/???, loss: 6.429941654205322
epoch: 0, update in batch 11700/???, loss: 5.597273826599121
epoch: 0, update in batch 11800/???, loss: 7.112508296966553
epoch: 0, update in batch 11900/???, loss: 6.745194911956787
epoch: 0, update in batch 12000/???, loss: 7.47100305557251
epoch: 0, update in batch 12100/???, loss: 6.847914695739746
epoch: 0, update in batch 12200/???, loss: 6.876992702484131
epoch: 0, update in batch 12300/???, loss: 6.499053955078125
epoch: 0, update in batch 12400/???, loss: 7.196413993835449
epoch: 0, update in batch 12500/???, loss: 6.593430995941162
epoch: 0, update in batch 12600/???, loss: 6.368945121765137
epoch: 0, update in batch 12700/???, loss: 6.362246513366699
epoch: 0, update in batch 12800/???, loss: 7.209506034851074
epoch: 0, update in batch 12900/???, loss: 6.8092780113220215
epoch: 0, update in batch 13000/???, loss: 8.273663520812988
epoch: 0, update in batch 13100/???, loss: 7.061187744140625
epoch: 0, update in batch 13200/???, loss: 5.778809547424316
epoch: 0, update in batch 13300/???, loss: 5.650263786315918
epoch: 0, update in batch 13400/???, loss: 5.9032440185546875
epoch: 0, update in batch 13500/???, loss: 6.629636287689209
epoch: 0, update in batch 13600/???, loss: 6.577019691467285
epoch: 0, update in batch 13700/???, loss: 5.953114032745361
epoch: 0, update in batch 13800/???, loss: 6.630902290344238
epoch: 0, update in batch 13900/???, loss: 7.593966484069824
epoch: 0, update in batch 14000/???, loss: 6.636081695556641
epoch: 0, update in batch 14100/???, loss: 5.772985458374023
epoch: 0, update in batch 14200/???, loss: 5.907249450683594
epoch: 0, update in batch 14300/???, loss: 7.863391876220703
epoch: 0, update in batch 14400/???, loss: 7.275572776794434
epoch: 0, update in batch 14500/???, loss: 6.818984031677246
epoch: 0, update in batch 14600/???, loss: 6.0456342697143555
epoch: 0, update in batch 14700/???, loss: 6.281990051269531
epoch: 0, update in batch 14800/???, loss: 6.197850227355957
epoch: 0, update in batch 14900/???, loss: 5.851240634918213
epoch: 0, update in batch 15000/???, loss: 6.826748847961426
epoch: 0, update in batch 15100/???, loss: 7.2189483642578125
epoch: 0, update in batch 15200/???, loss: 6.609204292297363
epoch: 0, update in batch 15300/???, loss: 6.947709560394287
epoch: 0, update in batch 15400/???, loss: 6.604478359222412
epoch: 0, update in batch 15500/???, loss: 6.222006797790527
epoch: 0, update in batch 15600/???, loss: 6.515635013580322
epoch: 0, update in batch 15700/???, loss: 6.40108585357666
epoch: 0, update in batch 15800/???, loss: 6.36106014251709
epoch: 0, update in batch 15900/???, loss: 6.533608436584473
epoch: 0, update in batch 16000/???, loss: 6.662516117095947
epoch: 0, update in batch 16100/???, loss: 7.284195899963379
epoch: 0, update in batch 16200/???, loss: 6.6524176597595215
epoch: 0, update in batch 16300/???, loss: 6.430756568908691
epoch: 0, update in batch 16400/???, loss: 7.515387058258057
epoch: 0, update in batch 16500/???, loss: 6.938241481781006
epoch: 0, update in batch 16600/???, loss: 5.860864162445068
epoch: 0, update in batch 16700/???, loss: 6.451329231262207
epoch: 0, update in batch 16800/???, loss: 6.5510663986206055
epoch: 0, update in batch 16900/???, loss: 7.3591437339782715
epoch: 0, update in batch 17000/???, loss: 6.158746719360352
epoch: 0, update in batch 17100/???, loss: 7.202520847320557
epoch: 0, update in batch 17200/???, loss: 6.80673885345459
epoch: 0, update in batch 17300/???, loss: 6.698304653167725
epoch: 0, update in batch 17400/???, loss: 5.743161201477051
epoch: 0, update in batch 17500/???, loss: 6.518529415130615
epoch: 0, update in batch 17600/???, loss: 6.021708011627197
epoch: 0, update in batch 17700/???, loss: 6.354712963104248
epoch: 0, update in batch 17800/???, loss: 6.323357582092285
epoch: 0, update in batch 17900/???, loss: 6.61548376083374
epoch: 0, update in batch 18000/???, loss: 6.600308895111084
epoch: 0, update in batch 18100/???, loss: 6.794068336486816
epoch: 0, update in batch 18200/???, loss: 7.487390041351318
epoch: 0, update in batch 18300/???, loss: 5.973461627960205
epoch: 0, update in batch 18400/???, loss: 6.891515254974365
epoch: 0, update in batch 18500/???, loss: 5.897144317626953
epoch: 0, update in batch 18600/???, loss: 6.6016364097595215
epoch: 0, update in batch 18700/???, loss: 6.948650360107422
epoch: 0, update in batch 18800/???, loss: 7.221627235412598
epoch: 0, update in batch 18900/???, loss: 6.817994117736816
epoch: 0, update in batch 19000/???, loss: 5.730655193328857
epoch: 0, update in batch 19100/???, loss: 6.236818790435791
epoch: 0, update in batch 19200/???, loss: 7.178666114807129
epoch: 0, update in batch 19300/???, loss: 6.77465295791626
epoch: 0, update in batch 19400/???, loss: 6.996792793273926
epoch: 0, update in batch 19500/???, loss: 6.80951452255249
epoch: 0, update in batch 19600/???, loss: 7.1757965087890625
epoch: 0, update in batch 19700/???, loss: 8.400952339172363
epoch: 0, update in batch 19800/???, loss: 7.1904473304748535
epoch: 0, update in batch 19900/???, loss: 6.339241981506348
epoch: 0, update in batch 20000/???, loss: 7.078637599945068
epoch: 0, update in batch 20100/???, loss: 5.015235900878906
epoch: 0, update in batch 20200/???, loss: 6.763777732849121
epoch: 0, update in batch 20300/???, loss: 6.543915748596191
epoch: 0, update in batch 20400/???, loss: 6.027902603149414
epoch: 0, update in batch 20500/???, loss: 6.710694789886475
epoch: 0, update in batch 20600/???, loss: 6.800978660583496
epoch: 0, update in batch 20700/???, loss: 6.371827125549316
epoch: 0, update in batch 20800/???, loss: 5.952463626861572
epoch: 0, update in batch 20900/???, loss: 6.317960739135742
epoch: 0, update in batch 21000/???, loss: 7.178386688232422
epoch: 0, update in batch 21100/???, loss: 6.887454986572266
epoch: 0, update in batch 21200/???, loss: 6.468400478363037
epoch: 0, update in batch 21300/???, loss: 7.8383684158325195
epoch: 0, update in batch 21400/???, loss: 5.850740909576416
epoch: 0, update in batch 21500/???, loss: 6.065464973449707
epoch: 0, update in batch 21600/???, loss: 7.537625312805176
epoch: 0, update in batch 21700/???, loss: 6.095994472503662
epoch: 0, update in batch 21800/???, loss: 6.342766761779785
epoch: 0, update in batch 21900/???, loss: 5.810301780700684
epoch: 0, update in batch 22000/???, loss: 6.447206974029541
epoch: 0, update in batch 22100/???, loss: 7.0662946701049805
epoch: 0, update in batch 22200/???, loss: 6.535088539123535
epoch: 0, update in batch 22300/???, loss: 7.017588138580322
epoch: 0, update in batch 22400/???, loss: 5.067782402038574
epoch: 0, update in batch 22500/???, loss: 6.493170738220215
epoch: 0, update in batch 22600/???, loss: 5.642627716064453
epoch: 0, update in batch 22700/???, loss: 7.200662136077881
epoch: 0, update in batch 22800/???, loss: 6.137134075164795
epoch: 0, update in batch 22900/???, loss: 6.367280006408691
epoch: 0, update in batch 23000/???, loss: 7.458652496337891
epoch: 0, update in batch 23100/???, loss: 6.515708923339844
epoch: 0, update in batch 23200/???, loss: 7.526422023773193
epoch: 0, update in batch 23300/???, loss: 6.653852939605713
epoch: 0, update in batch 23400/???, loss: 6.737251281738281
epoch: 0, update in batch 23500/???, loss: 6.493605136871338
epoch: 0, update in batch 23600/???, loss: 6.132809638977051
epoch: 0, update in batch 23700/???, loss: 6.406940460205078
epoch: 0, update in batch 23800/???, loss: 6.84005880355835
epoch: 0, update in batch 23900/???, loss: 6.830739498138428
epoch: 0, update in batch 24000/???, loss: 5.862464427947998
epoch: 0, update in batch 24100/???, loss: 6.382696628570557
epoch: 0, update in batch 24200/???, loss: 5.722895622253418
epoch: 0, update in batch 24300/???, loss: 6.697083473205566
epoch: 0, update in batch 24400/???, loss: 6.56771183013916
epoch: 0, update in batch 24500/???, loss: 7.566462516784668
epoch: 0, update in batch 24600/???, loss: 6.217026710510254
epoch: 0, update in batch 24700/???, loss: 7.164259433746338
epoch: 0, update in batch 24800/???, loss: 6.460946083068848
epoch: 0, update in batch 24900/???, loss: 6.333778381347656
epoch: 0, update in batch 25000/???, loss: 6.522342681884766
epoch: 0, update in batch 25100/???, loss: 6.270648002624512
epoch: 0, update in batch 25200/???, loss: 7.118265628814697
epoch: 0, update in batch 25300/???, loss: 5.8695197105407715
epoch: 0, update in batch 25400/???, loss: 5.92995023727417
epoch: 0, update in batch 25500/???, loss: 6.202570915222168
epoch: 0, update in batch 25600/???, loss: 6.4268975257873535
epoch: 0, update in batch 25700/???, loss: 6.710567474365234
epoch: 0, update in batch 25800/???, loss: 6.130914688110352
epoch: 0, update in batch 25900/???, loss: 6.082686424255371
epoch: 0, update in batch 26000/???, loss: 6.111697196960449
epoch: 0, update in batch 26100/???, loss: 7.320557594299316
epoch: 0, update in batch 26200/???, loss: 6.227985858917236
epoch: 0, update in batch 26300/???, loss: 6.204974174499512
epoch: 0, update in batch 26400/???, loss: 6.658400058746338
epoch: 0, update in batch 26500/???, loss: 5.911742687225342
epoch: 0, update in batch 26600/???, loss: 6.891500949859619
epoch: 0, update in batch 26700/???, loss: 5.763737201690674
epoch: 0, update in batch 26800/???, loss: 5.757307529449463
epoch: 0, update in batch 26900/???, loss: 6.076601982116699
epoch: 0, update in batch 27000/???, loss: 6.193032264709473
epoch: 0, update in batch 27100/???, loss: 6.120661735534668
epoch: 0, update in batch 27200/???, loss: 6.5425519943237305
epoch: 0, update in batch 27300/???, loss: 6.511394500732422
epoch: 0, update in batch 27400/???, loss: 7.127263069152832
epoch: 0, update in batch 27500/???, loss: 6.134243488311768
epoch: 0, update in batch 27600/???, loss: 6.5747809410095215
epoch: 0, update in batch 27700/???, loss: 6.351634979248047
epoch: 0, update in batch 27800/???, loss: 5.589611530303955
epoch: 0, update in batch 27900/???, loss: 6.916817665100098
epoch: 0, update in batch 28000/???, loss: 5.711864948272705
epoch: 0, update in batch 28100/???, loss: 6.921398162841797
epoch: 0, update in batch 28200/???, loss: 6.785823822021484
epoch: 0, update in batch 28300/???, loss: 6.007838249206543
epoch: 0, update in batch 28400/???, loss: 6.338862419128418
epoch: 0, update in batch 28500/???, loss: 6.9078168869018555
epoch: 0, update in batch 28600/???, loss: 6.710842132568359
epoch: 0, update in batch 28700/???, loss: 6.592329502105713
epoch: 0, update in batch 28800/???, loss: 6.184128761291504
epoch: 0, update in batch 28900/???, loss: 6.209361553192139
epoch: 0, update in batch 29000/???, loss: 7.067984104156494
epoch: 0, update in batch 29100/???, loss: 6.479236602783203
epoch: 0, update in batch 29200/???, loss: 6.413198947906494
epoch: 0, update in batch 29300/???, loss: 6.638579368591309
epoch: 0, update in batch 29400/???, loss: 5.938233375549316
epoch: 0, update in batch 29500/???, loss: 6.8490891456604
epoch: 0, update in batch 29600/???, loss: 6.111110210418701
epoch: 0, update in batch 29700/???, loss: 6.959462642669678
epoch: 0, update in batch 29800/???, loss: 6.964720726013184
epoch: 0, update in batch 29900/???, loss: 6.2007527351379395
epoch: 0, update in batch 30000/???, loss: 6.803907871246338
epoch: 0, update in batch 30100/???, loss: 5.665301322937012
epoch: 0, update in batch 30200/???, loss: 6.913702487945557
epoch: 0, update in batch 30300/???, loss: 6.824265956878662
epoch: 0, update in batch 30400/???, loss: 6.131905555725098
epoch: 0, update in batch 30500/???, loss: 5.799595832824707
epoch: 0, update in batch 30600/???, loss: 6.846949100494385
epoch: 0, update in batch 30700/???, loss: 6.481771945953369
epoch: 0, update in batch 30800/???, loss: 6.5581254959106445
epoch: 0, update in batch 30900/???, loss: 6.111696720123291
epoch: 0, update in batch 31000/???, loss: 4.8547563552856445
epoch: 0, update in batch 31100/???, loss: 6.5503740310668945
epoch: 0, update in batch 31200/???, loss: 6.212404251098633
epoch: 0, update in batch 31300/???, loss: 5.761624336242676
epoch: 0, update in batch 31400/???, loss: 7.043508052825928
epoch: 0, update in batch 31500/???, loss: 8.301980018615723
epoch: 0, update in batch 31600/???, loss: 5.655745506286621
epoch: 0, update in batch 31700/???, loss: 7.116888999938965
epoch: 0, update in batch 31800/???, loss: 6.237078666687012
epoch: 0, update in batch 31900/???, loss: 6.990937232971191
epoch: 0, update in batch 32000/???, loss: 6.327075958251953
epoch: 0, update in batch 32100/???, loss: 6.831456184387207
epoch: 0, update in batch 32200/???, loss: 6.511493682861328
epoch: 0, update in batch 32300/???, loss: 6.719797611236572
epoch: 0, update in batch 32400/???, loss: 6.46258544921875
epoch: 0, update in batch 32500/???, loss: 7.349535942077637
epoch: 0, update in batch 32600/???, loss: 5.773186683654785
epoch: 0, update in batch 32700/???, loss: 6.072037696838379
epoch: 0, update in batch 32800/???, loss: 7.044579982757568
epoch: 0, update in batch 32900/???, loss: 6.290024757385254
epoch: 0, update in batch 33000/???, loss: 7.101686000823975
epoch: 0, update in batch 33100/???, loss: 6.590539455413818
epoch: 0, update in batch 33200/???, loss: 6.944089412689209
epoch: 0, update in batch 33300/???, loss: 6.6709442138671875
epoch: 0, update in batch 33400/???, loss: 7.119935035705566
epoch: 0, update in batch 33500/???, loss: 6.845646858215332
epoch: 0, update in batch 33600/???, loss: 6.941410064697266
epoch: 0, update in batch 33700/???, loss: 6.341822624206543
epoch: 0, update in batch 33800/???, loss: 6.98660945892334
epoch: 0, update in batch 33900/???, loss: 7.544371128082275
epoch: 0, update in batch 34000/???, loss: 6.844598293304443
epoch: 0, update in batch 34100/???, loss: 6.958268642425537
epoch: 0, update in batch 34200/???, loss: 6.6372880935668945
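# predict_probs below also queries model_back, a second LSTM trained on the
# reversed token stream so it can score the gap from the right context. The
# cell that trains it did not survive in this dump; the reconstruction below
# mirrors the forward model's setup and is an assumption, not the recorded run.
train_dataset_back = TrainDataset(train_data.head(8000), index_to_key, key_to_index, True)
model_back = Model(100, vocab_size).to(device)
train(train_dataset_back, model_back, 1, 64)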
def predict_probs(left_tokens, right_tokens):
    model_front.eval()
    model_back.eval()

    # right_tokens arrive already reversed, matching how model_back was trained.
    x_left = torch.tensor([[key_to_index.get(w, key_to_index['<unk>']) for w in left_tokens]]).to(device)
    x_right = torch.tensor([[key_to_index.get(w, key_to_index['<unk>']) for w in right_tokens]]).to(device)
    with torch.no_grad():
        y_pred_left, (state_h_left, state_c_left) = model_front(x_left)
        y_pred_right, (state_h_right, state_c_right) = model_back(x_right)

    last_word_logits_left = y_pred_left[0][-1]
    last_word_logits_right = y_pred_right[0][-1]
    probs_left = torch.nn.functional.softmax(last_word_logits_left, dim=0).cpu().numpy()
    probs_right = torch.nn.functional.softmax(last_word_logits_right, dim=0).cpu().numpy()

    # Average the forward and backward distributions over the vocabulary.
    probs = [np.mean(k) for k in zip(probs_left, probs_right)]

    # Keep the 30 most probable words; the final index (<unk>) is never admitted.
    top_words = []
    for index in range(len(probs)):
        if len(top_words) < 30:
            top_words.append((probs[index], [index]))
        else:
            worst_word = min(top_words, key=lambda w: w[0])
            if worst_word[0] < probs[index] and index != len(probs) - 1:
                top_words.remove(worst_word)
                top_words.append((probs[index], [index]))

    # Emit the challenge output format: word:prob pairs plus the leftover
    # probability mass assigned to the empty word.
    prediction = ''
    sum_prob = 0.0
    for word in top_words:
        word_prob = word[0]
        word_text = index_to_key[word[1][0]]
        sum_prob += word_prob
        prediction += f'{word_text}:{word_prob} '
    prediction += f':{1 - sum_prob}'

    return prediction
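# The linear scan above can be replaced by a partial sort; a minimal numpy
# sketch (assuming avg is the averaged distribution as an ndarray):
# avg = (probs_left + probs_right) / 2
# top30 = np.argpartition(avg, -30)[-30:]       # indices of the 30 largest
# top30 = top30[np.argsort(avg[top30])[::-1]]   # optionally sorted by prob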
dev_data = pd.read_csv('dev-0/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
test_data = pd.read_csv('test-A/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
def write_predictions(data, out_path):
    with open(out_path, 'w') as file:
        for index, row in data.iterrows():
            left_words = word_tokenize(clean_text(str(row[6])))
            right_words = word_tokenize(clean_text(str(row[7])))
            # The backward model consumes the right context nearest the gap, reversed.
            right_words.reverse()
            if len(left_words) < 6 or len(right_words) < 6:
                prediction = ':1.0'
            else:
                prediction = predict_probs(left_words[-5:], right_words[-5:])
            file.write(prediction + '\n')
write_predictions(dev_data, 'dev-0/out.tsv')
write_predictions(test_data, 'test-A/out.tsv')