In [1]:
import torch
import lzma
from itertools import islice
import regex as re
import sys
from torchtext.vocab import build_vocab_from_iterator
from torch import nn
from torch.utils.data import IterableDataset
import itertools

In [None]:
"values,index  =prob_dist.topk(20) 
for x,y in (values,index):
  print(x,y)"

In [2]:
# with lzma.open("train/in.tsv.xz", encoding='utf8', mode="rt") as fh:
#     for line in fh:
#         # print(line)
#         pattern = r'\^\^|\n|\\|[<>]|[()]'
#         line = re.sub(pattern, '', line)
#         print(line)

In [3]:
def get_words_from_line(line):
  line = line.rstrip()
  yield '<s>'
  for t in line.split():
    yield t
  yield '</s>'


def get_word_lines_from_file(file_name):
  with lzma.open(file_name, encoding='utf8', mode="rt") as fh:
    for line in fh:
      pattern = r'\^\^|\n|\\|[<>]|[()]'
      line = re.sub(pattern, '', line)
      yield get_words_from_line(line)

vocab_size = 2500

vocab = build_vocab_from_iterator(
    get_word_lines_from_file("train/in.tsv.xz"),
    max_tokens = vocab_size,
    specials = ['<unk>'])



In [4]:
def look_ahead_iterator(gen):
   prev = None
   for item in gen:
      if prev is not None:
         yield (prev, item)
      prev = item

class Bigrams(IterableDataset):
  def __init__(self, text_file, vocabulary_size):
      self.vocab = vocab
      self.vocab.set_default_index(self.vocab['<unk>'])
      self.vocabulary_size = vocabulary_size
      self.text_file = text_file

  def __iter__(self):
     return look_ahead_iterator(
         (self.vocab[t] for t in itertools.chain.from_iterable(get_word_lines_from_file(self.text_file))))

train_dataset = Bigrams("train/in.tsv.xz", vocab_size)

In [5]:
from torch.utils.data import DataLoader

next(iter(train_dataset))

(24, 0)

In [6]:
next(iter(DataLoader(train_dataset, batch_size=5)))

[tensor([  24,    0, 1021,   25,    0]),
 tensor([   0, 1021,   25,    0,    0])]

In [7]:
embed_size = 200

class SimpleBigramNeuralLanguageModel(nn.Module):
  def __init__(self, vocabulary_size, embedding_size):
      super(SimpleBigramNeuralLanguageModel, self).__init__()
      self.model = nn.Sequential(
          nn.Embedding(vocabulary_size, embedding_size),
          nn.Linear(embedding_size, vocabulary_size),
          nn.Softmax()
      )

  def forward(self, x):
      return self.model(x)

model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size)

vocab.set_default_index(vocab['<unk>'])
ixs = torch.tensor(vocab.forward(['is']))
out = model(ixs)
out[0][vocab['is']]

  input = module(input)


tensor(0.0002, grad_fn=<SelectBackward0>)

In [8]:
device = 'cpu'
model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)
data = DataLoader(train_dataset, batch_size=6000)
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.NLLLoss()

model.train()
step = 0
for x, y in data:
   x = x.to(device)
   y = y.to(device)
   optimizer.zero_grad()
   ypredicted = model(x)
   loss = criterion(torch.log(ypredicted), y)
   if step % 100 == 0:
      print(step, loss)
   step += 1
   loss.backward()
   optimizer.step()


0 tensor(8.0177, grad_fn=<NllLossBackward0>)
100 tensor(5.2132, grad_fn=<NllLossBackward0>)


KeyboardInterrupt: 

In [9]:
torch.save(model.state_dict(), 'model1.bin')

In [15]:
device = 'cpu'
model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)
model.load_state_dict(torch.load('model1.bin'))
model.eval()

ixs = torch.tensor(vocab.forward(['he'])).to(device)
print(ixs)

out = model(ixs)
top = torch.topk(out[0], 10)
top_indices = top.indices.tolist()
top_probs = top.values.tolist()
top_words = vocab.lookup_tokens(top_indices)
print(top_words,'\n',top_indices,'\n',top_probs)
# list(zip(top_words, top_indices, top_probs))

tensor([19])
['had', 'was', 'has', 'did', 'would', 'went', 'could', 'said,', 'be', 'must'] 
 [36, 10, 37, 131, 49, 203, 99, 788, 11, 127] 
 [0.052415624260902405, 0.0358351431787014, 0.027370842173695564, 0.01275723148137331, 0.010935396887362003, 0.009363135322928429, 0.009096388705074787, 0.0070375604555010796, 0.005759422201663256, 0.005631867330521345]


In [16]:
def prediction(word: str) -> str:
    ixs = torch.tensor(vocab.forward([word])).to(device)
    out = model(ixs)
    top = torch.topk(out[0], 5)
    top_indices = top.indices.tolist()
    top_probs = top.values.tolist()
    top_words = vocab.lookup_tokens(top_indices)
    zipped = list(zip(top_words, top_probs))
    for index, element in enumerate(zipped):
        unk = None
        if '<unk>' in element:
            unk = zipped.pop(index)
            zipped.append(('', unk[1]))
            break
    if unk is None:
        zipped[-1] = ('', zipped[-1][1])
    print(' '.join([f'{x[0]}:{x[1]}' for x in zipped]))
    return ' '.join([f'{x[0]}:{x[1]}' for x in zipped])

In [None]:
prompt = 'Think about'
max_seq_len = 30
seed = 0

In [17]:
def create_outputs(folder_name):
    print(f'Creating outputs in {folder_name}')
    with lzma.open(f'{folder_name}/in.tsv.xz', mode='rt', encoding='utf-8') as fid:
        with open(f'{folder_name}/out.tsv', 'w', encoding='utf-8', newline='\n') as f:
            for line in fid:
                separated = line.split('\t')
                prefix = separated[6].replace(r'\n', ' ').split()[-1]
                output_line = prediction(prefix)
                f.write(output_line + '\n')

In [18]:
create_outputs('dev-0')
create_outputs('test-A')

Creating outputs in dev-0
the:0.01148879062384367 of:0.006031177006661892 a:0.005755534861236811 his:0.004050770308822393 :0.009219370782375336
the:0.023308811709284782 a:0.02173895761370659 own:0.012746201828122139 that:0.007281431928277016 :0.02100181020796299
the:0.02363738790154457 that:0.008539832197129726 a:0.008165501989424229 an:0.007165413349866867 :0.035898447036743164
of:0.05411530286073685 and:0.04940299689769745 the:0.0400581881403923 to:0.029938215389847755 :0.37532946467399597
of:0.05411530286073685 and:0.04940299689769745 the:0.0400581881403923 to:0.029938215389847755 :0.37532946467399597
a:0.07568275183439255 the:0.045853495597839355 his:0.02007143944501877 all:0.013846635818481445 :0.04955475032329559
the:0.00316080660559237 all:0.002639115322381258 them.:0.0024603919591754675 as:0.002170901047065854 :0.0020744542125612497
of:0.05411530286073685 and:0.04940299689769745 the:0.0400581881403923 to:0.029938215389847755 :0.37532946467399597
of:0.05411530286073685 and:0.049

  input = module(input)


the:0.23462098836898804 a:0.026843780651688576 this:0.012386118993163109 said:0.010683613829314709 :0.45802026987075806
of:0.05411530286073685 and:0.04940299689769745 the:0.0400581881403923 to:0.029938215389847755 :0.37532946467399597
any:0.0173477903008461 three:0.011614865623414516 other:0.009230944328010082 more:0.008342713117599487 :0.2566910684108734
a:0.007183287292718887 was:0.005211757030338049 not:0.004996122792363167 being:0.0044728415086865425 :0.02697567455470562
the:0.23462098836898804 a:0.026843780651688576 this:0.012386118993163109 said:0.010683613829314709 :0.45802026987075806
same:0.006965374108403921 said:0.006106619723141193 United:0.005009945016354322 first:0.00498473085463047 :0.6302009224891663
an:0.004460567608475685 their:0.004396467469632626 the:0.0036188296508044004 will:0.0034394278191030025 :0.010485290549695492
MORNING:0.0032735588029026985 cent,:0.0024784752167761326 own:0.0021904774475842714 provided,:0.001914860913529992 :0.0017092936905100942
of:0.01909

KeyboardInterrupt: 