challenging-america-word-ga.../simple_neural_network.ipynb

IMPORTS

import itertools
import lzma
import pickle
import sys

import regex as re
import torch
from torch import nn
from torch.utils.data import DataLoader, IterableDataset
from torchtext.vocab import build_vocab_from_iterator

print(torch.backends.mps.is_available())
print(torch.backends.mps.is_built())
True
True

FUNCTIONS

def get_words_from_line(line):
    line = line.rstrip()
    yield '<s>'
    for t in line.split(' '):
        yield t
    yield '</s>'

def get_word_lines_from_file(file_name):
    n = 0
    with lzma.open(file_name, 'r') as fh:
        for line in fh:
            n += 1
            if n % 1000 == 0:
                print(n, file=sys.stderr)
            yield get_words_from_line(line.decode('utf-8'))

def look_ahead_iterator(gen):
    prev = None
    for item in gen:
        if prev is not None:
            yield (prev, item)
        prev = item
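
For intuition, the iterator turns a token stream into consecutive (previous, current) pairs, i.e. bigrams. A toy check with hypothetical tokens:

list(look_ahead_iterator(iter(['<s>', 'the', 'cat', '</s>'])))
# -> [('<s>', 'the'), ('the', 'cat'), ('cat', '</s>')]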

def clean(text):
    # The corpus encodes newlines as the literal two-character sequence '\n',
    # hence the escaped replacements below.
    text = str(text).lower()
    text = text.replace('-\\n', '').replace('\\n', ' ').replace('-', '')
    text = (text.replace("'s", ' is').replace("'re", ' are').replace("'m", ' am')
                .replace("'ve", ' have').replace("'ll", ' will'))
    text = re.sub(r'\p{P}', '', text)
    return text
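
A toy call showing the lowering, literal-backslash-n removal, contraction expansion, and punctuation stripping (hypothetical input):

# '\\n' below is the two-character sequence backslash + n, as in the corpus.
clean("Don't stop-\\nnow, it's fine")
# -> 'dont stopnow it is fine'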

def predict(word, model, vocab):
    # Out-of-vocabulary words fall back to '<unk>'.
    if word not in vocab:
        word = '<unk>'
    ixs = torch.tensor(vocab.forward([word])).to(device)
    out = model(ixs)
    top = torch.topk(out[0], 300)
    top_indices = top.indices.tolist()
    top_probs = top.values.tolist()
    top_words = vocab.lookup_tokens(top_indices)
    prob_list = list(zip(top_words, top_probs))
    # Map '<unk>' (or, failing that, the last candidate) to the empty token '',
    # which carries the leftover probability mass in the output format.
    unk = None
    for index, element in enumerate(prob_list):
        if '<unk>' in element:
            unk = prob_list.pop(index)
            prob_list.append(('', unk[1]))
            break
    if unk is None:
        prob_list[-1] = ('', prob_list[-1][1])
    return ' '.join(f'{w}:{p}' for w, p in prob_list)

def prediction_for_file(model, vocab, folder, file):
    print('=' * 10, f' do prediction for {folder}/{file} ', '=' * 10)
    with lzma.open(f'{folder}/{file}', mode='rt', encoding='utf-8') as f, \
         open(f'{folder}/out.tsv', 'w', encoding='utf-8') as fid:
        for line in f:
            separated = line.split('\t')
            # Predict from the last word of the cleaned left context (column 7).
            before = clean(separated[6]).split()[-1]
            fid.write(predict(before, model, vocab) + '\n')

CLASSES

class Bigrams(IterableDataset):
    def __init__(self, text_file, vocabulary_size):
        self.vocab = build_vocab_from_iterator(
            get_word_lines_from_file(text_file),
            max_tokens=vocabulary_size,
            specials=['<unk>'])
        self.vocab.set_default_index(self.vocab['<unk>'])
        self.vocabulary_size = vocabulary_size
        self.text_file = text_file

    def __iter__(self):
        return look_ahead_iterator(
            (self.vocab[t] for t in itertools.chain.from_iterable(
                get_word_lines_from_file(self.text_file))))

class SimpleBigramNeuralLanguageModel(nn.Module):
    def __init__(self, vocabulary_size, embedding_size):
        super().__init__()
        self.model = nn.Sequential(
            nn.Embedding(vocabulary_size, embedding_size),
            nn.Linear(embedding_size, vocabulary_size),
            # Pass dim explicitly; implicit softmax dim is deprecated.
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        return self.model(x)
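
Since the training loop below applies torch.log to the softmax output before NLLLoss, a numerically safer variant (a sketch, not the model trained here) would emit log-probabilities directly:

class BigramLogSoftmaxLM(nn.Module):
    """Hypothetical variant: LogSoftmax + NLLLoss replaces log(softmax(x)),
    which can underflow to -inf for tiny probabilities."""
    def __init__(self, vocabulary_size, embedding_size):
        super().__init__()
        self.model = nn.Sequential(
            nn.Embedding(vocabulary_size, embedding_size),
            nn.Linear(embedding_size, vocabulary_size),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, x):
        return self.model(x)

With this variant the training loss becomes criterion(model(x), y) with no explicit log, and predict() would need torch.exp on the outputs.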

PARAMETERS

vocab_size = 30000
embed_size = 1000
batch_size = 5000
device = 'mps'
path_to_training_file = './train/in.tsv.xz'
path_to_model_file = 'model_neural_network.bin'
folder_dev_0, file_dev_0 = 'dev-0', 'in.tsv.xz'
folder_test_a, file_test_a = 'test-A', 'in.tsv.xz'
path_to_vocabulary_file = 'vocabulary_neural_network.pickle'
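
This run targets Apple's MPS backend; a portable fallback (a sketch, not part of the original run) could pick the device instead of hard-coding it:

# Hypothetical portable device selection; the run recorded here used 'mps'.
device = (
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)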

VOCAB

vocab = build_vocab_from_iterator(
    get_word_lines_from_file(path_to_training_file),
    max_tokens=vocab_size,
    specials=['<unk>'])
# Match the Bigrams class: unknown tokens map to '<unk>'.
vocab.set_default_index(vocab['<unk>'])

with open(path_to_vocabulary_file, 'wb') as handle:
    pickle.dump(vocab, handle, protocol=pickle.HIGHEST_PROTOCOL)
[vocabulary-build progress elided: line counts printed every 1,000 lines, 1000 … 432000]

TRAIN MODEL

train_dataset = Bigrams(path_to_training_file, vocab_size)
model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)
data = DataLoader(train_dataset, batch_size=batch_size)
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.NLLLoss()

model.train()
step = 0
for x, y in data:
    x = x.to(device)
    y = y.to(device)
    optimizer.zero_grad()
    ypredicted = model(x)
    # NLLLoss expects log-probabilities, hence torch.log over the softmax output.
    loss = criterion(torch.log(ypredicted), y)
    if step % 100 == 0:
        print(step, loss)
    step += 1
    loss.backward()
    optimizer.step()

torch.save(model.state_dict(), path_to_model_file)
[line-count progress (1000 … 432000) and per-100-step losses elided; abridged loss trace:]
0 tensor(10.5058, device='mps:0', grad_fn=<NllLossBackward0>)
100 tensor(7.3365, device='mps:0', grad_fn=<NllLossBackward0>)
500 tensor(5.8481, device='mps:0', grad_fn=<NllLossBackward0>)
1000 tensor(5.5190, device='mps:0', grad_fn=<NllLossBackward0>)
5000 tensor(5.1523, device='mps:0', grad_fn=<NllLossBackward0>)
10000 tensor(5.3396, device='mps:0', grad_fn=<NllLossBackward0>)
15000 tensor(5.1504, device='mps:0', grad_fn=<NllLossBackward0>)
20000 tensor(5.2132, device='mps:0', grad_fn=<NllLossBackward0>)
24300 tensor(5.2729, device='mps:0', grad_fn=<NllLossBackward0>)
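
A rough sanity check (a sketch, not part of the original run): perplexity over the first few training batches, reusing the data loader and criterion from above; a proper check would use a held-out split with the same vocabulary.

model.eval()
total_nll, total_tokens = 0.0, 0
with torch.no_grad():
    for x, y in itertools.islice(data, 20):
        out = model(x.to(device))
        # NLLLoss averages over the batch; rescale to a token-level sum.
        total_nll += criterion(torch.log(out), y.to(device)).item() * y.numel()
        total_tokens += y.numel()
print('perplexity:', torch.exp(torch.tensor(total_nll / total_tokens)).item())
model.train()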

LOAD MODEL AND VOCAB

with open(path_to_vocabulary_file, 'rb') as handle:
    vocab = pickle.load(handle)
model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)
model.load_state_dict(torch.load(path_to_model_file))
model.eval()
SimpleBigramNeuralLanguageModel(
  (model): Sequential(
    (0): Embedding(30000, 1000)
    (1): Linear(in_features=1000, out_features=30000, bias=True)
    (2): Softmax(dim=1)
  )
)

CREATE OUTPUT FILES

DEV-0

prediction_for_file(model, vocab, folder_dev_0, file_dev_0)
==========  do prediction for dev-0/in.tsv.xz  ==========

TEST-A

prediction_for_file(model, vocab, folder_test_a, file_test_a)
==========  do prediction for test-A/in.tsv.xz  ==========