In [1]:
import pandas as pd
import numpy as np
import regex as re
import csv
import torch
from torch import nn
from gensim.models import Word2Vec
from nltk.tokenize import word_tokenize

In [2]:
torch.cuda.empty_cache()
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [3]:
def clean_text(text):
    text = text.lower().replace('-\\\\\\\\n', '').replace('\\\\\\\\n', ' ')
    text = re.sub(r'\p{P}', '', text)
    text = text.replace("'t", " not").replace("'s", " is").replace("'ll", " will").replace("'m", " am").replace("'ve", " have")

    return text

In [4]:
train_data = pd.read_csv('train/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
train_labels = pd.read_csv('train/expected.tsv', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)

train_data = train_data[[6, 7]]
train_data = pd.concat([train_data, train_labels], axis=1)

In [5]:
class TrainCorpus:
    def __init__(self, data):
        self.data = data
        
    def __iter__(self):
        for _, row in self.data.iterrows():
            text = str(row[6]) + str(row[0]) + str(row[7])
            text = clean_text(text)
            yield word_tokenize(text)

In [6]:
train_sentences = TrainCorpus(train_data.head(80000))
w2v_model = Word2Vec(vector_size=100, min_count=10)

In [7]:
w2v_model.build_vocab(corpus_iterable=train_sentences)

key_to_index = w2v_model.wv.key_to_index
index_to_key = w2v_model.wv.index_to_key

index_to_key.append('<unk>')
key_to_index['<unk>'] = len(index_to_key) - 1

vocab_size = len(index_to_key)
print(vocab_size)

81477


In [8]:
class TrainDataset(torch.utils.data.IterableDataset):
    def __init__(self, data, index_to_key, key_to_index, reversed=False):
        self.reversed = reversed
        self.data = data
        self.index_to_key = index_to_key
        self.key_to_index = key_to_index
        self.vocab_size = len(key_to_index)

    def __iter__(self):
        for _, row in self.data.iterrows():
            text = str(row[6]) + str(row[0]) + str(row[7])
            text = clean_text(text)
            tokens = word_tokenize(text)
            if self.reversed:
                tokens = list(reversed(tokens))
            for i in range(5, len(tokens), 1):
                input_context = tokens[i-5:i]
                target_context = tokens[i-4:i+1]
                #gap_word = tokens[i]
            
                input_embed = [self.key_to_index[word] if word in self.key_to_index else self.key_to_index['<unk>'] for word in input_context]
                target_embed = [self.key_to_index[word] if word in self.key_to_index else self.key_to_index['<unk>'] for word in target_context]
                #word_index = self.key_to_index[gap_word] if gap_word in self.key_to_index else self.key_to_index['<unk>']
                #word_embed = np.concatenate([np.zeros(word_index), np.ones(1), np.zeros(vocab_size - word_index - 1)])
                
                yield np.asarray(input_embed, dtype=np.int64), np.asarray(target_embed, dtype=np.int64)

In [9]:
class Model(nn.Module):
    def __init__(self, embed_size, vocab_size):
        super(Model, self).__init__()
        self.embed_size = embed_size
        self.vocab_size = vocab_size
        self.lstm_size = 128
        self.num_layers = 2
        
        self.embed = nn.Embedding(num_embeddings=vocab_size, embedding_dim=self.embed_size)
        self.lstm = nn.LSTM(input_size=self.embed_size, hidden_size=self.lstm_size, num_layers=self.num_layers, dropout=0.2)
        self.fc = nn.Linear(self.lstm_size, vocab_size)

    def forward(self, x, prev_state = None):
        embed = self.embed(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        probs = torch.softmax(logits, dim=1)
        return logits, state

    def init_state(self, sequence_length):
        zeros = torch.zeros(self.num_layers, sequence_length, self.gru_size).to(device)
        return (zeros, zeros)

In [10]:
from torch.utils.data import DataLoader
from torch.optim import Adam

def train(dataset, model, max_epochs, batch_size):
    model.train()

    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=0.001)

    for epoch in range(max_epochs):
        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            
            x = x.to(device)
            y = y.to(device)
            
            y_pred, (state_h, state_c) = model(x)
            loss = criterion(y_pred.transpose(1, 2), y)

            loss.backward()
            optimizer.step()
            
            if batch % 1000 == 0:
                print(f'epoch: {epoch}, update in batch {batch}/???, loss: {loss.item()}')

In [11]:
train_dataset_front = TrainDataset(train_data.head(80000), index_to_key, key_to_index, False)
train_dataset_back = TrainDataset(train_data.tail(80000), index_to_key, key_to_index, True)

In [12]:
model_front = Model(100, vocab_size).to(device)
model_back = Model(100, vocab_size).to(device)

In [13]:
train(train_dataset_front, model_front, 1, 64)

epoch: 0, update in batch 0/???, loss: 11.314821243286133
epoch: 0, update in batch 1000/???, loss: 6.876476287841797
epoch: 0, update in batch 2000/???, loss: 7.133523464202881
epoch: 0, update in batch 3000/???, loss: 6.979971885681152
epoch: 0, update in batch 4000/???, loss: 7.018368721008301
epoch: 0, update in batch 5000/???, loss: 6.494096279144287
epoch: 0, update in batch 6000/???, loss: 6.448479652404785
epoch: 0, update in batch 7000/???, loss: 6.526387691497803
epoch: 0, update in batch 8000/???, loss: 6.536323547363281
epoch: 0, update in batch 9000/???, loss: 6.4919538497924805
epoch: 0, update in batch 10000/???, loss: 6.435188293457031
epoch: 0, update in batch 11000/???, loss: 6.934823513031006
epoch: 0, update in batch 12000/???, loss: 7.410381317138672
epoch: 0, update in batch 13000/???, loss: 8.227864265441895
epoch: 0, update in batch 14000/???, loss: 6.7139105796813965
epoch: 0, update in batch 15000/???, loss: 6.82781457901001
epoch: 0, update in batch 16000/???

epoch: 0, update in batch 134000/???, loss: 6.055495262145996
epoch: 0, update in batch 135000/???, loss: 6.363399982452393
epoch: 0, update in batch 136000/???, loss: 7.45733642578125
epoch: 0, update in batch 137000/???, loss: 6.960203647613525
epoch: 0, update in batch 138000/???, loss: 6.986503601074219
epoch: 0, update in batch 139000/???, loss: 5.7938127517700195
epoch: 0, update in batch 140000/???, loss: 5.559916019439697
epoch: 0, update in batch 141000/???, loss: 5.551616668701172
epoch: 0, update in batch 142000/???, loss: 5.386819839477539
epoch: 0, update in batch 143000/???, loss: 6.826618194580078
epoch: 0, update in batch 144000/???, loss: 6.106345176696777
epoch: 0, update in batch 145000/???, loss: 6.812024116516113
epoch: 0, update in batch 146000/???, loss: 6.347486972808838
epoch: 0, update in batch 147000/???, loss: 6.20189094543457
epoch: 0, update in batch 148000/???, loss: 5.5717034339904785
epoch: 0, update in batch 149000/???, loss: 6.884232521057129
epoch: 0

epoch: 0, update in batch 267000/???, loss: 6.188825607299805
epoch: 0, update in batch 268000/???, loss: 6.518698215484619
epoch: 0, update in batch 269000/???, loss: 5.663434028625488
epoch: 0, update in batch 270000/???, loss: 5.978742599487305
epoch: 0, update in batch 271000/???, loss: 6.217379093170166
epoch: 0, update in batch 272000/???, loss: 5.426600933074951
epoch: 0, update in batch 273000/???, loss: 6.7220964431762695
epoch: 0, update in batch 274000/???, loss: 4.276306629180908
epoch: 0, update in batch 275000/???, loss: 5.420112609863281
epoch: 0, update in batch 276000/???, loss: 5.934456825256348
epoch: 0, update in batch 277000/???, loss: 7.186459541320801
epoch: 0, update in batch 278000/???, loss: 6.126835823059082
epoch: 0, update in batch 279000/???, loss: 5.727339267730713
epoch: 0, update in batch 280000/???, loss: 5.725864410400391
epoch: 0, update in batch 281000/???, loss: 5.47005033493042
epoch: 0, update in batch 282000/???, loss: 6.217499732971191
epoch: 0

In [14]:
train(train_dataset_back, model_back, 1, 64)

epoch: 0, update in batch 0/???, loss: 11.3253755569458
epoch: 0, update in batch 1000/???, loss: 5.709358215332031
epoch: 0, update in batch 2000/???, loss: 7.989391326904297
epoch: 0, update in batch 3000/???, loss: 6.578714847564697
epoch: 0, update in batch 4000/???, loss: 7.051873207092285
epoch: 0, update in batch 5000/???, loss: 6.85653018951416
epoch: 0, update in batch 6000/???, loss: 6.812790870666504
epoch: 0, update in batch 7000/???, loss: 6.9604010581970215
epoch: 0, update in batch 8000/???, loss: 6.798591613769531
epoch: 0, update in batch 9000/???, loss: 6.415241241455078
epoch: 0, update in batch 10000/???, loss: 6.6636223793029785
epoch: 0, update in batch 11000/???, loss: 6.593747138977051
epoch: 0, update in batch 12000/???, loss: 6.914702415466309
epoch: 0, update in batch 13000/???, loss: 5.542675971984863
epoch: 0, update in batch 14000/???, loss: 6.5461883544921875
epoch: 0, update in batch 15000/???, loss: 7.507067680358887
epoch: 0, update in batch 16000/???,

epoch: 0, update in batch 134000/???, loss: 5.503821849822998
epoch: 0, update in batch 135000/???, loss: 6.428965091705322
epoch: 0, update in batch 136000/???, loss: 5.102448463439941
epoch: 0, update in batch 137000/???, loss: 6.239314556121826
epoch: 0, update in batch 138000/???, loss: 6.028595447540283
epoch: 0, update in batch 139000/???, loss: 6.407244682312012
epoch: 0, update in batch 140000/???, loss: 5.597055912017822
epoch: 0, update in batch 141000/???, loss: 5.823704719543457
epoch: 0, update in batch 142000/???, loss: 6.665535926818848
epoch: 0, update in batch 143000/???, loss: 5.5736894607543945
epoch: 0, update in batch 144000/???, loss: 6.723180294036865
epoch: 0, update in batch 145000/???, loss: 6.378345489501953
epoch: 0, update in batch 146000/???, loss: 5.6936845779418945
epoch: 0, update in batch 147000/???, loss: 5.761658668518066
epoch: 0, update in batch 148000/???, loss: 5.580254077911377
epoch: 0, update in batch 149000/???, loss: 5.733176231384277
epoch:

epoch: 0, update in batch 267000/???, loss: 6.055562973022461
epoch: 0, update in batch 268000/???, loss: 5.796906471252441
epoch: 0, update in batch 269000/???, loss: 5.654962539672852
epoch: 0, update in batch 270000/???, loss: 6.574362277984619
epoch: 0, update in batch 271000/???, loss: 6.256768226623535
epoch: 0, update in batch 272000/???, loss: 6.8345208168029785
epoch: 0, update in batch 273000/???, loss: 6.066469669342041
epoch: 0, update in batch 274000/???, loss: 6.625809669494629
epoch: 0, update in batch 275000/???, loss: 4.762896537780762
epoch: 0, update in batch 276000/???, loss: 6.019833564758301
epoch: 0, update in batch 277000/???, loss: 6.227939605712891
epoch: 0, update in batch 278000/???, loss: 7.046879768371582
epoch: 0, update in batch 279000/???, loss: 6.068551540374756
epoch: 0, update in batch 280000/???, loss: 6.454771995544434
epoch: 0, update in batch 281000/???, loss: 3.9379985332489014
epoch: 0, update in batch 282000/???, loss: 5.615240097045898
epoch:

In [30]:
def predict_probs(left_tokens, right_tokens):
    model_front.eval()
    model_back.eval()

    x_left = torch.tensor([[train_dataset_front.key_to_index[w] if w in key_to_index else train_dataset_front.key_to_index['<unk>'] for w in left_tokens]]).to(device)
    x_right = torch.tensor([[train_dataset_front.key_to_index[w] if w in key_to_index else train_dataset_front.key_to_index['<unk>'] for w in right_tokens]]).to(device)
    y_pred_left, (state_h_left, state_c_left) = model_front(x_left)
    y_pred_right, (state_h_right, state_c_right) = model_back(x_right)

    last_word_logits_left = y_pred_left[0][-1]
    last_word_logits_right = y_pred_right[0][-1]
    probs_left = torch.nn.functional.softmax(last_word_logits_left, dim=0).detach().cpu().numpy()
    probs_right = torch.nn.functional.softmax(last_word_logits_right, dim=0).detach().cpu().numpy()
    
    probs = [np.mean(k) for k in zip(probs_left, probs_right)]
    
    top_words = []
    for index in range(len(probs)):
        if len(top_words) < 30:
            top_words.append((probs[index], [index]))
        else:
            worst_word = None
            for word in top_words:
                if not worst_word:
                    worst_word = word
                else:
                    if word[0] < worst_word[0]:
                        worst_word = word
            if worst_word[0] < probs[index] and index != len(probs) - 1:
                top_words.remove(worst_word)
                top_words.append((probs[index], [index]))
                
    prediction = ''
    sum_prob = 0.0
    for word in top_words:
        sum_prob += word[0]
        word_index = word[0]
        word_text = index_to_key[word[1][0]]
        prediction += f'{word_text}:{word_index} '
    prediction += f':{1 - sum_prob}'
    
    return prediction

In [16]:
dev_data = pd.read_csv('dev-0/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)
test_data = pd.read_csv('test-A/in.tsv.xz', sep='\t', on_bad_lines='skip', header=None, quoting=csv.QUOTE_NONE)

In [39]:
with open('dev-0/out.tsv', 'w') as file:
    for index, row in dev_data.iterrows():
        left_text = clean_text(str(row[6]))
        right_text = clean_text(str(row[7]))
        left_words = word_tokenize(left_text)
        right_words = word_tokenize(right_text)
        right_words.reverse()
        if len(left_words) < 6 or len(right_words) < 6:
            prediction = ':1.0'
        else:
            prediction = predict_probs(left_words[-5:], right_words[-5:])
        file.write(prediction + '\n')

In [41]:
with open('test-A/out.tsv', 'w') as file:
    for index, row in test_data.iterrows():
        left_text = clean_text(str(row[6]))
        right_text = clean_text(str(row[7]))
        left_words = word_tokenize(left_text)
        right_words = word_tokenize(right_text)
        right_words.reverse()
        if len(left_words) < 6 or len(right_words) < 6:
            prediction = ':1.0'
        else:
            prediction = predict_probs(left_words[-5:], right_words[-5:])
        file.write(prediction + '\n')