# NOTE: notebook-export artifact ("36 KiB" attachment-size labels) — not code.
import csv
import heapq

import numpy as np
import pandas as pd
import regex as re
import torch
from gensim.models import Word2Vec
from nltk.tokenize import word_tokenize
from torch import nn
def clean_text(text):
    """Lowercase *text*, mend hyphenated line breaks, expand contractions,
    and strip all Unicode punctuation.

    NOTE(review): two fixes vs. the original —
    * the replace targets were garbled to long backslash runs by the
      notebook export; restored to the literal '-\n' / '\n' sequences;
    * contractions were expanded AFTER punctuation removal, so the
      apostrophe patterns ("'t", "'s", ...) could never match. They are
      now expanded first, while apostrophes still exist.
    """
    text = text.lower().replace('-\n', '').replace('\n', ' ')
    # Expand common contractions (author's patterns kept verbatim).
    text = (text.replace("'t", " not").replace("'s", " is")
                .replace("'ll", " will").replace("'m", " am")
                .replace("'ve", " have"))
    # \p{P} (any Unicode punctuation) needs the third-party `regex`
    # module, imported as `re` at the top of the file.
    text = re.sub(r'\p{P}', '', text)
    return text
# Load training contexts (col 6 = left context, col 7 = right context) and
# the gold gap words, then attach the labels as an extra column.
read_opts = dict(sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
train_data = pd.read_csv('train/in.tsv.xz', **read_opts)
train_labels = pd.read_csv('train/expected.tsv', **read_opts)
train_data = pd.concat([train_data[[6, 7]], train_labels], axis=1)
class TrainCorpus:
    """Iterable of tokenized sentences for gensim's Word2Vec vocab build.

    Each row yields the tokens of: left context (col 6) + gap word
    (col 0, the gold label) + right context (col 7).
    """

    def __init__(self, data):
        self.data = data

    def __iter__(self):
        for _, row in self.data.iterrows():
            # Fix: join with spaces — the original concatenated the three
            # fields directly, fusing the boundary words with the label
            # into bogus tokens (e.g. 'wordLABELword').
            text = ' '.join((str(row[6]), str(row[0]), str(row[7])))
            yield word_tokenize(clean_text(text))
# Build a Word2Vec vocabulary over the first 80k rows; the model is used
# purely as a convenient token <-> index mapping (it is never trained).
train_sentences = TrainCorpus(train_data.head(80000))
w2v_model = Word2Vec(vector_size=100, min_count=10)
w2v_model.build_vocab(corpus_iterable=train_sentences)
key_to_index = w2v_model.wv.key_to_index
index_to_key = w2v_model.wv.index_to_key
# Register an out-of-vocabulary sentinel as the final index.
key_to_index['<unk>'] = len(index_to_key)
index_to_key.append('<unk>')
vocab_size = len(index_to_key)
print(vocab_size)
# printed output: 81477  (vocabulary size including '<unk>')
class TrainDataset(torch.utils.data.IterableDataset):
    """Streams (input, target) index windows for LM training.

    For every position i >= 5 in a row's token stream, yields a 5-token
    input window (tokens i-5..i-1) and the 5-token target window shifted
    right by one (tokens i-4..i), both as int64 numpy arrays of word ids.
    """

    def __init__(self, data, index_to_key, key_to_index, reversed=False):
        # `reversed` shadows the builtin, but the parameter name is kept
        # for backward compatibility with existing callers.
        self.reversed = reversed
        self.data = data
        self.index_to_key = index_to_key
        self.key_to_index = key_to_index
        self.vocab_size = len(key_to_index)

    def __iter__(self):
        unk = self.key_to_index['<unk>']  # hoisted out of the token loop
        for _, row in self.data.iterrows():
            # Fix: join with spaces — direct concatenation fused the
            # boundary words with the gap word into bogus tokens.
            text = ' '.join((str(row[6]), str(row[0]), str(row[7])))
            tokens = word_tokenize(clean_text(text))
            if self.reversed:
                tokens = tokens[::-1]
            ids = [self.key_to_index.get(word, unk) for word in tokens]
            for i in range(5, len(ids)):
                yield (np.asarray(ids[i - 5:i], dtype=np.int64),
                       np.asarray(ids[i - 4:i + 1], dtype=np.int64))
class Model(nn.Module):
    """Two-layer LSTM language model: embedding -> LSTM -> vocab logits."""

    def __init__(self, embed_size, vocab_size):
        super().__init__()
        self.embed_size = embed_size
        self.vocab_size = vocab_size
        self.lstm_size = 128
        self.num_layers = 2
        self.embed = nn.Embedding(num_embeddings=vocab_size,
                                  embedding_dim=embed_size)
        self.lstm = nn.LSTM(input_size=embed_size, hidden_size=self.lstm_size,
                            num_layers=self.num_layers, dropout=0.2)
        self.fc = nn.Linear(self.lstm_size, vocab_size)

    def forward(self, x, prev_state=None):
        """Return (logits, (h, c)).

        The original also computed torch.softmax(logits, dim=1) into an
        unused local (with a dubious dim for a 3-D tensor); dropped —
        CrossEntropyLoss expects raw logits.
        """
        embed = self.embed(x)
        output, state = self.lstm(embed, prev_state)
        return self.fc(output), state

    def init_state(self, sequence_length):
        """Return zeroed (h, c) states of shape
        (num_layers, sequence_length, lstm_size).

        Fixes the original AttributeError (`self.gru_size` never existed)
        and places the tensors on the model's own device instead of
        relying on a module-level `device` global.
        """
        device = next(self.parameters()).device
        zeros = torch.zeros(self.num_layers, sequence_length,
                            self.lstm_size, device=device)
        return (zeros, zeros)
from torch.utils.data import DataLoader
from torch.optim import Adam
def train(dataset, model, max_epochs, batch_size):
    """Train `model` on `dataset` with Adam + cross-entropy.

    Interface unchanged; the compute device is now derived from the
    model's own parameters instead of a module-level `device` global,
    so the function works standalone.
    """
    model.train()
    device = next(model.parameters()).device
    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.001)
    for epoch in range(max_epochs):
        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            x = x.to(device)
            y = y.to(device)
            # Model returns (logits, (h, c)); the state is not carried over.
            y_pred, _ = model(x)
            # (batch, seq, vocab) -> (batch, vocab, seq), the class-dim
            # layout CrossEntropyLoss expects for sequence targets.
            loss = criterion(y_pred.transpose(1, 2), y)
            loss.backward()
            optimizer.step()
            if batch % 100 == 0:
                print(f'epoch: {epoch}, update in batch {batch}/???, loss: {loss.item()}')
        # Release cached GPU memory between epochs; skipped on CPU-only runs.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
# Select GPU when available; this module-level `device` is read by the
# training/inference code below.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Forward-direction dataset over the first 8k rows (reversed=False).
train_dataset_front = TrainDataset(train_data.head(8000), index_to_key, key_to_index, False)
# Embedding size 100 matches the Word2Vec vector_size used for the vocab.
model_front = Model(100, vocab_size).to(device)
train(train_dataset_front, model_front, 1, 64)
# Training log (notebook output, condensed): 1 epoch, ~34,200 batches;
# loss fell from ~11.32 at batch 0 to ~6.64 at batch 34,200, hovering
# around 6-7 for most of the run.
def predict_probs(left_tokens, right_tokens):
    """Blend the forward (model_front) and backward (model_back) LM
    distributions for the gap word.

    Returns a string of the 30 most probable 'word:prob' pairs followed
    by a residual ':<leftover mass>' entry.

    NOTE(review): `model_back` is defined outside this chunk — assumed to
    be a Model trained on reversed text; verify against the full notebook.
    """
    model_front.eval()
    # Fix: the backward model was never switched out of train mode, so
    # its dropout stayed active during inference.
    model_back.eval()
    unk = key_to_index['<unk>']

    def encode(tokens):
        # Map tokens to a (1, len) id tensor, falling back to <unk>.
        return torch.tensor([[key_to_index.get(w, unk) for w in tokens]]).to(device)

    y_pred_left, _ = model_front(encode(left_tokens))
    y_pred_right, _ = model_back(encode(right_tokens))
    # Distribution over the vocabulary at the last position of each side.
    probs_left = torch.nn.functional.softmax(y_pred_left[0][-1], dim=0).detach().cpu().numpy()
    probs_right = torch.nn.functional.softmax(y_pred_right[0][-1], dim=0).detach().cpu().numpy()
    # Average the two distributions (vectorized; the original looped in
    # Python over the whole vocabulary with np.mean per element).
    probs = (probs_left + probs_right) / 2.0
    # Top 30 candidate indices, excluding the final index (<unk>) —
    # replaces the original O(vocab * 30) manual worst-element scan.
    top = heapq.nlargest(30, range(len(probs) - 1), key=probs.__getitem__)
    prediction = ''
    sum_prob = 0.0
    for idx in top:
        p = probs[idx]
        sum_prob += p
        prediction += f'{index_to_key[idx]}:{p} '
    # Residual mass assigned to "any other word".
    prediction += f':{1 - sum_prob}'
    return prediction
dev_data = pd.read_csv('dev-0/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
test_data = pd.read_csv('test-A/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)


def _write_predictions(data, out_path):
    """Write one gap-word prediction line per row of `data` to `out_path`.

    Extracted helper: the dev-0 and test-A loops below were duplicated
    verbatim. Behavior is unchanged.
    """
    with open(out_path, 'w') as file:
        for _, row in data.iterrows():
            left_words = word_tokenize(clean_text(str(row[6])))
            right_words = word_tokenize(clean_text(str(row[7])))
            # The backward model reads the right context toward the gap,
            # so its tokens are consumed in reverse order.
            right_words.reverse()
            if len(left_words) < 6 or len(right_words) < 6:
                # Not enough context for a 5-token window on each side:
                # put all probability mass on "any other word".
                prediction = ':1.0'
            else:
                prediction = predict_probs(left_words[-5:], right_words[-5:])
            file.write(prediction + '\n')


_write_predictions(dev_data, 'dev-0/out.tsv')
_write_predictions(test_data, 'test-A/out.tsv')