In [1]:
import torch
from torch import nn

torch.cuda.empty_cache()

In [2]:
import pandas as pd
import regex as re
import csv

def clean_text(text):
    text = text.lower().replace('-\\\\n', '').replace('\\\\n', ' ')
    text = re.sub(r'\p{P}', '', text)
    text = text.replace("'t", " not").replace("'s", " is").replace("'ll", " will").replace("'m", " am").replace("'ve", " have")

    return text

In [3]:
train_data = pd.read_csv('train/in.tsv.xz', sep='\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)
train_labels = pd.read_csv('train/expected.tsv', sep='\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)

train_data = train_data[[6, 7]]
train_data = pd.concat([train_data, train_labels], axis=1)

train_data['text'] = train_data[6] + train_data[0] + train_data[7]
train_data = train_data[['text']]

with open('processed_train.txt', 'w', encoding='utf-8') as file:
    for _, row in train_data.iterrows():
        text = clean_text(str(row['text']))
        file.write(text + '\n')

In [4]:
vocab_size = 40000
embed_size = 300
hidden_size = 128

class SimpleTrigramNeuralLanguageModel(nn.Module):
    def __init__(self, vocabulary_size, embedding_size, hidden_size):
        super(SimpleTrigramNeuralLanguageModel, self).__init__()
        self.embedding = nn.Embedding(vocabulary_size * 2, embedding_size)
        self.linear1 = nn.Linear(embedding_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, vocabulary_size * 2)

    def forward(self, x):
        x = self.embedding(x)
        x = self.linear1(x)
        x = self.linear2(x)
        x = torch.softmax(x, dim=1)
        return x

In [5]:
import regex as re
from itertools import islice, chain
from torchtext.vocab import build_vocab_from_iterator
from torch.utils.data import IterableDataset

def get_words_from_line(line):
    line = line.rstrip()
    yield '<s>'
    for m in re.finditer(r'[\p{L}0-9\*]+|\p{P}+', line):
        yield m.group(0).lower()
    yield '</s>'

def get_word_lines_from_file(file_name):
    with open(file_name, 'r', encoding='utf-8') as fh:
        for line in fh:
            yield get_words_from_line(line)
            
def look_ahead_iterator(gen):
    prev_1 = None
    prev_2 = None
    for item in gen:
        if prev_1 and prev_2:
            yield (prev_2 + prev_1, item)
        prev_2 = prev_1
        prev_1 = item

In [6]:
class Trigrams(IterableDataset):
    def __init__(self, text_file, vocabulary_size):
        self.vocab = build_vocab_from_iterator(
            get_word_lines_from_file(text_file),
            max_tokens = vocabulary_size,
            specials = ['<unk>']
        )
        self.vocab.set_default_index(self.vocab['<unk>'])
        self.vocabulary_size = vocabulary_size
        self.text_file = text_file

    def __iter__(self):
        return look_ahead_iterator((self.vocab[t] for t in chain.from_iterable(get_word_lines_from_file(self.text_file))))

In [7]:
from torch.utils.data import DataLoader

device = 'cuda' if torch.cuda.is_available() else 'cpu'
train_dataset = Trigrams('processed_train.txt', vocab_size)
model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)
data = DataLoader(train_dataset, batch_size=800)
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.NLLLoss()

In [8]:
step = 0

for epoch in range(2):
    model.train()
    for x, y in data:
        x = x.to(device)
        y = y.to(device)
        optimizer.zero_grad()
        outputs = model(x)
        loss = criterion(torch.log(outputs), y)
        if step % 100 == 0:
            print(step, loss)
        step += 1
        loss.backward()
        optimizer.step()

0 tensor(11.3293, device='cuda:0', grad_fn=<NllLossBackward0>)
100 tensor(8.9417, device='cuda:0', grad_fn=<NllLossBackward0>)
200 tensor(7.0454, device='cuda:0', grad_fn=<NllLossBackward0>)
300 tensor(6.8511, device='cuda:0', grad_fn=<NllLossBackward0>)
400 tensor(6.8680, device='cuda:0', grad_fn=<NllLossBackward0>)
500 tensor(6.8153, device='cuda:0', grad_fn=<NllLossBackward0>)
600 tensor(6.5640, device='cuda:0', grad_fn=<NllLossBackward0>)
700 tensor(6.8175, device='cuda:0', grad_fn=<NllLossBackward0>)
800 tensor(6.6864, device='cuda:0', grad_fn=<NllLossBackward0>)
900 tensor(6.7530, device='cuda:0', grad_fn=<NllLossBackward0>)
1000 tensor(6.5542, device='cuda:0', grad_fn=<NllLossBackward0>)
1100 tensor(6.5068, device='cuda:0', grad_fn=<NllLossBackward0>)
1200 tensor(6.7081, device='cuda:0', grad_fn=<NllLossBackward0>)
1300 tensor(6.2363, device='cuda:0', grad_fn=<NllLossBackward0>)
1400 tensor(6.5277, device='cuda:0', grad_fn=<NllLossBackward0>)
1500 tensor(6.5607, device='cuda:0',

12600 tensor(6.5779, device='cuda:0', grad_fn=<NllLossBackward0>)
12700 tensor(6.1300, device='cuda:0', grad_fn=<NllLossBackward0>)
12800 tensor(6.3179, device='cuda:0', grad_fn=<NllLossBackward0>)
12900 tensor(6.5471, device='cuda:0', grad_fn=<NllLossBackward0>)
13000 tensor(6.2621, device='cuda:0', grad_fn=<NllLossBackward0>)
13100 tensor(6.4863, device='cuda:0', grad_fn=<NllLossBackward0>)
13200 tensor(6.4671, device='cuda:0', grad_fn=<NllLossBackward0>)
13300 tensor(6.5966, device='cuda:0', grad_fn=<NllLossBackward0>)
13400 tensor(6.3855, device='cuda:0', grad_fn=<NllLossBackward0>)
13500 tensor(6.4136, device='cuda:0', grad_fn=<NllLossBackward0>)
13600 tensor(6.4274, device='cuda:0', grad_fn=<NllLossBackward0>)
13700 tensor(6.3050, device='cuda:0', grad_fn=<NllLossBackward0>)
13800 tensor(6.4028, device='cuda:0', grad_fn=<NllLossBackward0>)
13900 tensor(6.1994, device='cuda:0', grad_fn=<NllLossBackward0>)
14000 tensor(6.2238, device='cuda:0', grad_fn=<NllLossBackward0>)
14100 tens

25100 tensor(6.5767, device='cuda:0', grad_fn=<NllLossBackward0>)
25200 tensor(6.4680, device='cuda:0', grad_fn=<NllLossBackward0>)
25300 tensor(6.4083, device='cuda:0', grad_fn=<NllLossBackward0>)
25400 tensor(6.2756, device='cuda:0', grad_fn=<NllLossBackward0>)
25500 tensor(6.0596, device='cuda:0', grad_fn=<NllLossBackward0>)
25600 tensor(6.5235, device='cuda:0', grad_fn=<NllLossBackward0>)
25700 tensor(6.3478, device='cuda:0', grad_fn=<NllLossBackward0>)
25800 tensor(6.3905, device='cuda:0', grad_fn=<NllLossBackward0>)
25900 tensor(6.7624, device='cuda:0', grad_fn=<NllLossBackward0>)
26000 tensor(6.4832, device='cuda:0', grad_fn=<NllLossBackward0>)
26100 tensor(6.4504, device='cuda:0', grad_fn=<NllLossBackward0>)
26200 tensor(6.1166, device='cuda:0', grad_fn=<NllLossBackward0>)
26300 tensor(6.2660, device='cuda:0', grad_fn=<NllLossBackward0>)
26400 tensor(6.2220, device='cuda:0', grad_fn=<NllLossBackward0>)
26500 tensor(6.3001, device='cuda:0', grad_fn=<NllLossBackward0>)
26600 tens

37600 tensor(6.3473, device='cuda:0', grad_fn=<NllLossBackward0>)
37700 tensor(6.3019, device='cuda:0', grad_fn=<NllLossBackward0>)
37800 tensor(6.3526, device='cuda:0', grad_fn=<NllLossBackward0>)
37900 tensor(6.3167, device='cuda:0', grad_fn=<NllLossBackward0>)
38000 tensor(6.5604, device='cuda:0', grad_fn=<NllLossBackward0>)
38100 tensor(6.2682, device='cuda:0', grad_fn=<NllLossBackward0>)
38200 tensor(6.3246, device='cuda:0', grad_fn=<NllLossBackward0>)
38300 tensor(6.4815, device='cuda:0', grad_fn=<NllLossBackward0>)
38400 tensor(6.3199, device='cuda:0', grad_fn=<NllLossBackward0>)
38500 tensor(6.3742, device='cuda:0', grad_fn=<NllLossBackward0>)
38600 tensor(6.3012, device='cuda:0', grad_fn=<NllLossBackward0>)
38700 tensor(6.2586, device='cuda:0', grad_fn=<NllLossBackward0>)
38800 tensor(6.3830, device='cuda:0', grad_fn=<NllLossBackward0>)
38900 tensor(6.4648, device='cuda:0', grad_fn=<NllLossBackward0>)
39000 tensor(6.2475, device='cuda:0', grad_fn=<NllLossBackward0>)
39100 tens

50100 tensor(6.3152, device='cuda:0', grad_fn=<NllLossBackward0>)
50200 tensor(5.8161, device='cuda:0', grad_fn=<NllLossBackward0>)
50300 tensor(6.1519, device='cuda:0', grad_fn=<NllLossBackward0>)
50400 tensor(6.2640, device='cuda:0', grad_fn=<NllLossBackward0>)
50500 tensor(6.6373, device='cuda:0', grad_fn=<NllLossBackward0>)
50600 tensor(6.0610, device='cuda:0', grad_fn=<NllLossBackward0>)
50700 tensor(6.1604, device='cuda:0', grad_fn=<NllLossBackward0>)
50800 tensor(6.0850, device='cuda:0', grad_fn=<NllLossBackward0>)
50900 tensor(6.5230, device='cuda:0', grad_fn=<NllLossBackward0>)
51000 tensor(6.3261, device='cuda:0', grad_fn=<NllLossBackward0>)
51100 tensor(6.1690, device='cuda:0', grad_fn=<NllLossBackward0>)
51200 tensor(6.3807, device='cuda:0', grad_fn=<NllLossBackward0>)
51300 tensor(6.1361, device='cuda:0', grad_fn=<NllLossBackward0>)
51400 tensor(6.4120, device='cuda:0', grad_fn=<NllLossBackward0>)
51500 tensor(6.1421, device='cuda:0', grad_fn=<NllLossBackward0>)
51600 tens

62600 tensor(6.3776, device='cuda:0', grad_fn=<NllLossBackward0>)
62700 tensor(6.2917, device='cuda:0', grad_fn=<NllLossBackward0>)
62800 tensor(6.0079, device='cuda:0', grad_fn=<NllLossBackward0>)
62900 tensor(6.4841, device='cuda:0', grad_fn=<NllLossBackward0>)
63000 tensor(6.4510, device='cuda:0', grad_fn=<NllLossBackward0>)
63100 tensor(6.3967, device='cuda:0', grad_fn=<NllLossBackward0>)
63200 tensor(6.3568, device='cuda:0', grad_fn=<NllLossBackward0>)
63300 tensor(6.1641, device='cuda:0', grad_fn=<NllLossBackward0>)
63400 tensor(6.2656, device='cuda:0', grad_fn=<NllLossBackward0>)
63500 tensor(6.2119, device='cuda:0', grad_fn=<NllLossBackward0>)
63600 tensor(6.3500, device='cuda:0', grad_fn=<NllLossBackward0>)
63700 tensor(6.5353, device='cuda:0', grad_fn=<NllLossBackward0>)
63800 tensor(6.3988, device='cuda:0', grad_fn=<NllLossBackward0>)
63900 tensor(6.4113, device='cuda:0', grad_fn=<NllLossBackward0>)
64000 tensor(5.9131, device='cuda:0', grad_fn=<NllLossBackward0>)
64100 tens

75100 tensor(6.4348, device='cuda:0', grad_fn=<NllLossBackward0>)
75200 tensor(6.2299, device='cuda:0', grad_fn=<NllLossBackward0>)
75300 tensor(6.3492, device='cuda:0', grad_fn=<NllLossBackward0>)
75400 tensor(6.5882, device='cuda:0', grad_fn=<NllLossBackward0>)
75500 tensor(6.2069, device='cuda:0', grad_fn=<NllLossBackward0>)
75600 tensor(6.5318, device='cuda:0', grad_fn=<NllLossBackward0>)
75700 tensor(6.1249, device='cuda:0', grad_fn=<NllLossBackward0>)
75800 tensor(6.3609, device='cuda:0', grad_fn=<NllLossBackward0>)
75900 tensor(6.4399, device='cuda:0', grad_fn=<NllLossBackward0>)
76000 tensor(6.4117, device='cuda:0', grad_fn=<NllLossBackward0>)
76100 tensor(6.3236, device='cuda:0', grad_fn=<NllLossBackward0>)
76200 tensor(6.1960, device='cuda:0', grad_fn=<NllLossBackward0>)
76300 tensor(6.3030, device='cuda:0', grad_fn=<NllLossBackward0>)
76400 tensor(6.7321, device='cuda:0', grad_fn=<NllLossBackward0>)
76500 tensor(6.4889, device='cuda:0', grad_fn=<NllLossBackward0>)
76600 tens

87600 tensor(6.1922, device='cuda:0', grad_fn=<NllLossBackward0>)
87700 tensor(6.3410, device='cuda:0', grad_fn=<NllLossBackward0>)
87800 tensor(6.5634, device='cuda:0', grad_fn=<NllLossBackward0>)
87900 tensor(6.3292, device='cuda:0', grad_fn=<NllLossBackward0>)
88000 tensor(6.4881, device='cuda:0', grad_fn=<NllLossBackward0>)
88100 tensor(6.1968, device='cuda:0', grad_fn=<NllLossBackward0>)
88200 tensor(6.0463, device='cuda:0', grad_fn=<NllLossBackward0>)
88300 tensor(6.0094, device='cuda:0', grad_fn=<NllLossBackward0>)
88400 tensor(6.2273, device='cuda:0', grad_fn=<NllLossBackward0>)
88500 tensor(6.2220, device='cuda:0', grad_fn=<NllLossBackward0>)
88600 tensor(6.4040, device='cuda:0', grad_fn=<NllLossBackward0>)
88700 tensor(6.5188, device='cuda:0', grad_fn=<NllLossBackward0>)
88800 tensor(6.2047, device='cuda:0', grad_fn=<NllLossBackward0>)
88900 tensor(6.3574, device='cuda:0', grad_fn=<NllLossBackward0>)
89000 tensor(6.2601, device='cuda:0', grad_fn=<NllLossBackward0>)
89100 tens

100100 tensor(6.6032, device='cuda:0', grad_fn=<NllLossBackward0>)
100200 tensor(6.3257, device='cuda:0', grad_fn=<NllLossBackward0>)
100300 tensor(6.3809, device='cuda:0', grad_fn=<NllLossBackward0>)
100400 tensor(6.1088, device='cuda:0', grad_fn=<NllLossBackward0>)
100500 tensor(6.3217, device='cuda:0', grad_fn=<NllLossBackward0>)
100600 tensor(6.2402, device='cuda:0', grad_fn=<NllLossBackward0>)
100700 tensor(6.5221, device='cuda:0', grad_fn=<NllLossBackward0>)
100800 tensor(6.3372, device='cuda:0', grad_fn=<NllLossBackward0>)
100900 tensor(6.3466, device='cuda:0', grad_fn=<NllLossBackward0>)
101000 tensor(6.2795, device='cuda:0', grad_fn=<NllLossBackward0>)
101100 tensor(6.3551, device='cuda:0', grad_fn=<NllLossBackward0>)
101200 tensor(6.1093, device='cuda:0', grad_fn=<NllLossBackward0>)
101300 tensor(6.0571, device='cuda:0', grad_fn=<NllLossBackward0>)
101400 tensor(6.2678, device='cuda:0', grad_fn=<NllLossBackward0>)
101500 tensor(6.3374, device='cuda:0', grad_fn=<NllLossBackwar

112400 tensor(6.2920, device='cuda:0', grad_fn=<NllLossBackward0>)
112500 tensor(6.0232, device='cuda:0', grad_fn=<NllLossBackward0>)
112600 tensor(6.1691, device='cuda:0', grad_fn=<NllLossBackward0>)
112700 tensor(6.2214, device='cuda:0', grad_fn=<NllLossBackward0>)
112800 tensor(6.4647, device='cuda:0', grad_fn=<NllLossBackward0>)
112900 tensor(6.4600, device='cuda:0', grad_fn=<NllLossBackward0>)
113000 tensor(6.1718, device='cuda:0', grad_fn=<NllLossBackward0>)
113100 tensor(6.2358, device='cuda:0', grad_fn=<NllLossBackward0>)
113200 tensor(6.3690, device='cuda:0', grad_fn=<NllLossBackward0>)
113300 tensor(6.3420, device='cuda:0', grad_fn=<NllLossBackward0>)
113400 tensor(5.6514, device='cuda:0', grad_fn=<NllLossBackward0>)
113500 tensor(6.3852, device='cuda:0', grad_fn=<NllLossBackward0>)
113600 tensor(6.4675, device='cuda:0', grad_fn=<NllLossBackward0>)
113700 tensor(6.1993, device='cuda:0', grad_fn=<NllLossBackward0>)
113800 tensor(6.0725, device='cuda:0', grad_fn=<NllLossBackwar

124700 tensor(6.2165, device='cuda:0', grad_fn=<NllLossBackward0>)
124800 tensor(6.0142, device='cuda:0', grad_fn=<NllLossBackward0>)
124900 tensor(6.3649, device='cuda:0', grad_fn=<NllLossBackward0>)
125000 tensor(6.1046, device='cuda:0', grad_fn=<NllLossBackward0>)
125100 tensor(6.6507, device='cuda:0', grad_fn=<NllLossBackward0>)
125200 tensor(6.3295, device='cuda:0', grad_fn=<NllLossBackward0>)
125300 tensor(6.4071, device='cuda:0', grad_fn=<NllLossBackward0>)
125400 tensor(6.4771, device='cuda:0', grad_fn=<NllLossBackward0>)
125500 tensor(6.5995, device='cuda:0', grad_fn=<NllLossBackward0>)
125600 tensor(5.8743, device='cuda:0', grad_fn=<NllLossBackward0>)
125700 tensor(6.2433, device='cuda:0', grad_fn=<NllLossBackward0>)
125800 tensor(6.1171, device='cuda:0', grad_fn=<NllLossBackward0>)
125900 tensor(5.7314, device='cuda:0', grad_fn=<NllLossBackward0>)
126000 tensor(6.5950, device='cuda:0', grad_fn=<NllLossBackward0>)
126100 tensor(6.4330, device='cuda:0', grad_fn=<NllLossBackwar

137000 tensor(6.1552, device='cuda:0', grad_fn=<NllLossBackward0>)
137100 tensor(6.3681, device='cuda:0', grad_fn=<NllLossBackward0>)
137200 tensor(6.2283, device='cuda:0', grad_fn=<NllLossBackward0>)
137300 tensor(6.5981, device='cuda:0', grad_fn=<NllLossBackward0>)
137400 tensor(6.3275, device='cuda:0', grad_fn=<NllLossBackward0>)
137500 tensor(6.3605, device='cuda:0', grad_fn=<NllLossBackward0>)
137600 tensor(6.1404, device='cuda:0', grad_fn=<NllLossBackward0>)
137700 tensor(5.9959, device='cuda:0', grad_fn=<NllLossBackward0>)
137800 tensor(5.8553, device='cuda:0', grad_fn=<NllLossBackward0>)
137900 tensor(6.1422, device='cuda:0', grad_fn=<NllLossBackward0>)
138000 tensor(6.3613, device='cuda:0', grad_fn=<NllLossBackward0>)
138100 tensor(6.4806, device='cuda:0', grad_fn=<NllLossBackward0>)
138200 tensor(6.3342, device='cuda:0', grad_fn=<NllLossBackward0>)
138300 tensor(6.1325, device='cuda:0', grad_fn=<NllLossBackward0>)
138400 tensor(6.3189, device='cuda:0', grad_fn=<NllLossBackwar

149300 tensor(6.0544, device='cuda:0', grad_fn=<NllLossBackward0>)
149400 tensor(6.3166, device='cuda:0', grad_fn=<NllLossBackward0>)
149500 tensor(6.1771, device='cuda:0', grad_fn=<NllLossBackward0>)
149600 tensor(6.5291, device='cuda:0', grad_fn=<NllLossBackward0>)
149700 tensor(6.3477, device='cuda:0', grad_fn=<NllLossBackward0>)
149800 tensor(6.5005, device='cuda:0', grad_fn=<NllLossBackward0>)
149900 tensor(6.0765, device='cuda:0', grad_fn=<NllLossBackward0>)
150000 tensor(6.2168, device='cuda:0', grad_fn=<NllLossBackward0>)
150100 tensor(5.9786, device='cuda:0', grad_fn=<NllLossBackward0>)
150200 tensor(6.3884, device='cuda:0', grad_fn=<NllLossBackward0>)
150300 tensor(6.3308, device='cuda:0', grad_fn=<NllLossBackward0>)
150400 tensor(6.2943, device='cuda:0', grad_fn=<NllLossBackward0>)
150500 tensor(5.9515, device='cuda:0', grad_fn=<NllLossBackward0>)
150600 tensor(6.1360, device='cuda:0', grad_fn=<NllLossBackward0>)
150700 tensor(6.1946, device='cuda:0', grad_fn=<NllLossBackwar

161600 tensor(6.4305, device='cuda:0', grad_fn=<NllLossBackward0>)
161700 tensor(6.1910, device='cuda:0', grad_fn=<NllLossBackward0>)
161800 tensor(6.0620, device='cuda:0', grad_fn=<NllLossBackward0>)
161900 tensor(6.3450, device='cuda:0', grad_fn=<NllLossBackward0>)
162000 tensor(6.0384, device='cuda:0', grad_fn=<NllLossBackward0>)
162100 tensor(6.2913, device='cuda:0', grad_fn=<NllLossBackward0>)
162200 tensor(6.4014, device='cuda:0', grad_fn=<NllLossBackward0>)
162300 tensor(6.1961, device='cuda:0', grad_fn=<NllLossBackward0>)
162400 tensor(6.3429, device='cuda:0', grad_fn=<NllLossBackward0>)
162500 tensor(6.1807, device='cuda:0', grad_fn=<NllLossBackward0>)
162600 tensor(6.1816, device='cuda:0', grad_fn=<NllLossBackward0>)
162700 tensor(6.5639, device='cuda:0', grad_fn=<NllLossBackward0>)
162800 tensor(6.1019, device='cuda:0', grad_fn=<NllLossBackward0>)
162900 tensor(6.1725, device='cuda:0', grad_fn=<NllLossBackward0>)
163000 tensor(6.5369, device='cuda:0', grad_fn=<NllLossBackwar

173900 tensor(6.1105, device='cuda:0', grad_fn=<NllLossBackward0>)
174000 tensor(6.5883, device='cuda:0', grad_fn=<NllLossBackward0>)
174100 tensor(6.2517, device='cuda:0', grad_fn=<NllLossBackward0>)
174200 tensor(6.3595, device='cuda:0', grad_fn=<NllLossBackward0>)
174300 tensor(6.0311, device='cuda:0', grad_fn=<NllLossBackward0>)
174400 tensor(5.8500, device='cuda:0', grad_fn=<NllLossBackward0>)
174500 tensor(6.0308, device='cuda:0', grad_fn=<NllLossBackward0>)
174600 tensor(6.4416, device='cuda:0', grad_fn=<NllLossBackward0>)
174700 tensor(6.3174, device='cuda:0', grad_fn=<NllLossBackward0>)
174800 tensor(6.0302, device='cuda:0', grad_fn=<NllLossBackward0>)
174900 tensor(5.8741, device='cuda:0', grad_fn=<NllLossBackward0>)
175000 tensor(6.1169, device='cuda:0', grad_fn=<NllLossBackward0>)
175100 tensor(6.5659, device='cuda:0', grad_fn=<NllLossBackward0>)
175200 tensor(6.2329, device='cuda:0', grad_fn=<NllLossBackward0>)
175300 tensor(6.2470, device='cuda:0', grad_fn=<NllLossBackwar

186200 tensor(6.3534, device='cuda:0', grad_fn=<NllLossBackward0>)
186300 tensor(6.0138, device='cuda:0', grad_fn=<NllLossBackward0>)
186400 tensor(6.0890, device='cuda:0', grad_fn=<NllLossBackward0>)
186500 tensor(6.2881, device='cuda:0', grad_fn=<NllLossBackward0>)
186600 tensor(6.0929, device='cuda:0', grad_fn=<NllLossBackward0>)
186700 tensor(6.0125, device='cuda:0', grad_fn=<NllLossBackward0>)
186800 tensor(5.6474, device='cuda:0', grad_fn=<NllLossBackward0>)
186900 tensor(6.4010, device='cuda:0', grad_fn=<NllLossBackward0>)
187000 tensor(6.3343, device='cuda:0', grad_fn=<NllLossBackward0>)
187100 tensor(5.7301, device='cuda:0', grad_fn=<NllLossBackward0>)
187200 tensor(6.2032, device='cuda:0', grad_fn=<NllLossBackward0>)
187300 tensor(6.2289, device='cuda:0', grad_fn=<NllLossBackward0>)
187400 tensor(6.4082, device='cuda:0', grad_fn=<NllLossBackward0>)
187500 tensor(6.2096, device='cuda:0', grad_fn=<NllLossBackward0>)
187600 tensor(6.4854, device='cuda:0', grad_fn=<NllLossBackwar

198500 tensor(6.5452, device='cuda:0', grad_fn=<NllLossBackward0>)
198600 tensor(6.1089, device='cuda:0', grad_fn=<NllLossBackward0>)
198700 tensor(6.4479, device='cuda:0', grad_fn=<NllLossBackward0>)
198800 tensor(6.3462, device='cuda:0', grad_fn=<NllLossBackward0>)
198900 tensor(5.9453, device='cuda:0', grad_fn=<NllLossBackward0>)
199000 tensor(6.3335, device='cuda:0', grad_fn=<NllLossBackward0>)
199100 tensor(6.4890, device='cuda:0', grad_fn=<NllLossBackward0>)
199200 tensor(6.1730, device='cuda:0', grad_fn=<NllLossBackward0>)
199300 tensor(6.2879, device='cuda:0', grad_fn=<NllLossBackward0>)
199400 tensor(6.0476, device='cuda:0', grad_fn=<NllLossBackward0>)
199500 tensor(6.1549, device='cuda:0', grad_fn=<NllLossBackward0>)
199600 tensor(6.4688, device='cuda:0', grad_fn=<NllLossBackward0>)
199700 tensor(6.2297, device='cuda:0', grad_fn=<NllLossBackward0>)
199800 tensor(6.4936, device='cuda:0', grad_fn=<NllLossBackward0>)
199900 tensor(6.4739, device='cuda:0', grad_fn=<NllLossBackwar

210800 tensor(6.4275, device='cuda:0', grad_fn=<NllLossBackward0>)
210900 tensor(6.1214, device='cuda:0', grad_fn=<NllLossBackward0>)
211000 tensor(6.0207, device='cuda:0', grad_fn=<NllLossBackward0>)
211100 tensor(6.1209, device='cuda:0', grad_fn=<NllLossBackward0>)
211200 tensor(6.2109, device='cuda:0', grad_fn=<NllLossBackward0>)
211300 tensor(6.0009, device='cuda:0', grad_fn=<NllLossBackward0>)
211400 tensor(6.2715, device='cuda:0', grad_fn=<NllLossBackward0>)
211500 tensor(6.4340, device='cuda:0', grad_fn=<NllLossBackward0>)
211600 tensor(6.4781, device='cuda:0', grad_fn=<NllLossBackward0>)
211700 tensor(6.2207, device='cuda:0', grad_fn=<NllLossBackward0>)
211800 tensor(6.2370, device='cuda:0', grad_fn=<NllLossBackward0>)
211900 tensor(5.9837, device='cuda:0', grad_fn=<NllLossBackward0>)
212000 tensor(6.2359, device='cuda:0', grad_fn=<NllLossBackward0>)
212100 tensor(6.4122, device='cuda:0', grad_fn=<NllLossBackward0>)


In [9]:
torch.save(model.state_dict(), 'model/model1.bin')

In [23]:
device = 'cuda'
model = SimpleTrigramNeuralLanguageModel(vocab_size, embed_size, hidden_size).to(device)
model.load_state_dict(torch.load('model/model1.bin'))
model.eval()

def predict(words):
    ixs = torch.tensor(train_dataset.vocab.forward(['with'])).to(device)
    predictions = model(ixs)
    top = torch.topk(out[0], 30)
    top_indices = top.indices.tolist()
    top_probs = top.values.tolist()
    top_words = train_dataset.vocab.lookup_tokens(top_indices)
    top_preds = list(zip(top_words, top_indices, top_probs))
    
    total_prob = 0.0
    pred_str = ''
    for word, _, prob in top_preds:
        if word != '<unk>':
            pred_str += f'{word}:{prob} '
            total_prob += prob
    pred_str += f':{1 - total_prob}'
    
    return pred_str

In [24]:
dev_data = pd.read_csv('dev-0/in.tsv.xz', sep='\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)
test_data = pd.read_csv('test-A/in.tsv.xz', sep='\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)

In [27]:
from nltk import word_tokenize

with open('dev-0/out.tsv', 'w') as file:
    for index, row in dev_data.iterrows():
        left_text = clean_text(str(row[6]))
        left_words = word_tokenize(left_text)
        if len(left_words) < 3:
            prediction = ':1.0'
        else:
            prediction = predict(left_words[-2:])
        file.write(prediction + '\n')

In [28]:
with open('test-A/out.tsv', 'w') as file:
    for index, row in test_data.iterrows():
        left_text = clean_text(str(row[6]))
        left_words = word_tokenize(left_text)
        if len(left_words) < 3:
            prediction = ':1.0'
        else:
            prediction = predict(left_words[-2:])
        file.write(prediction + '\n')