In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

torch.manual_seed(1)




### Data loading

In [3]:
import pickle
import lzma
import regex as re


def load_pickle(filename):
 with open(filename, "rb") as f:
 return pickle.load(f)


def save_pickle(d):
 with open("vocabulary.pkl", "wb") as f:
 pickle.dump(d, f, protocol=pickle.HIGHEST_PROTOCOL)


def clean_document(document: str) -> str:
 document = document.lower().replace("’", "'")
 document = re.sub(r"'s|[\-­]\\n", "", document)
 document = re.sub(
 r"(\\+n|[{}\[\]”&:•¦()*0-9;\"«»$\-><^,®¬¿?¡!#+. \t\n])+", " ", document
 )
 for to_find, substitute in zip(
 ["i'm", "won't", "n't", "'ll"], ["i am", "will not", " not", " will"]
 ):
 document = document.replace(to_find, substitute)
 return document


def get_words_from_line(line, clean_text=True):
 if clean_text:
 line = clean_document(line) # .rstrip()
 else:
 line = line.strip()
 yield ""
 for m in re.finditer(r"[\p{L}0-9\*]+|\p{P}+", line):
 yield m.group(0).lower()
 yield ""


def get_word_lines_from_file(file_name, clean_text=True, only_text=False):
 with lzma.open(file_name, "r") as fh:
 for i, line in enumerate(fh):
 if only_text:
 line = "\t".join(line.decode("utf-8").split("\t")[:-2])
 else:
 line = line.decode("utf-8")
 if i % 10000 == 0:
 print(i)
 yield get_words_from_line(line, clean_text)


### Dataclasses

In [5]:
from torch.utils.data import IterableDataset
from torchtext.vocab import build_vocab_from_iterator
import itertools


VOCAB_SIZE = 20000


def look_ahead_iterator(gen):
 prev = None
 for item in gen:
 if prev is not None:
 yield (prev, item)
 prev = item


class Bigrams(IterableDataset):
 def __init__(
 self, text_file, vocabulary_size, vocab=None, only_text=False, clean_text=True
 ):
 self.vocab = (
 build_vocab_from_iterator(
 get_word_lines_from_file(text_file, clean_text, only_text),
 max_tokens=vocabulary_size,
 specials=[""],
 )
 if vocab is None
 else vocab
 )
 self.vocab.set_default_index(self.vocab[""])
 self.vocabulary_size = vocabulary_size
 self.text_file = text_file
 self.clean_text = clean_text
 self.only_text = only_text

 def __iter__(self):
 return look_ahead_iterator(
 (
 self.vocab[t]
 for t in itertools.chain.from_iterable(
 get_word_lines_from_file(
 self.text_file, self.clean_text, self.only_text
 )
 )
 )
 )


vocab = None # torch.load('./vocab.pth')

train_dataset = Bigrams("/content/train/in.tsv.xz", VOCAB_SIZE, vocab, clean_text=False)


0
10000
20000
30000
40000
50000
60000
70000
80000
90000
100000
110000
120000
130000
140000
150000
160000
170000
180000
190000
200000
210000
220000
230000
240000
250000
260000
270000
280000
290000
300000
310000
320000
330000
340000
350000
360000
370000
380000
390000
400000
410000
420000
430000


In [6]:
# torch.save(train_dataset.vocab, "vocab.pth")
# torch.save(train_dataset.vocab, "vocab_only_text.pth")
# torch.save(train_dataset.vocab, "vocab_only_text_clean.pth")
torch.save(train_dataset.vocab, "vocab_2.pth")


### Model definition

In [7]:
class SimpleBigramNeuralLanguageModel(nn.Module):
 def __init__(self, vocabulary_size, embedding_size):
 super(SimpleBigramNeuralLanguageModel, self).__init__()
 self.model = nn.Sequential(
 nn.Embedding(vocabulary_size, embedding_size),
 nn.Linear(embedding_size, vocabulary_size),
 nn.Softmax(),
 )

 def forward(self, x):
 return self.model(x)


In [8]:
EMBED_SIZE = 100

model = SimpleBigramNeuralLanguageModel(VOCAB_SIZE, EMBED_SIZE)


In [9]:
from torch.utils.data import DataLoader

device = "cuda"
model = SimpleBigramNeuralLanguageModel(VOCAB_SIZE, EMBED_SIZE).to(device)
data = DataLoader(train_dataset, batch_size=5000)
optimizer = torch.optim.Adam(model.parameters())
criterion = torch.nn.NLLLoss()

model.train()
step = 0
for x, y in data:
 x = x.to(device)
 y = y.to(device)
 optimizer.zero_grad()
 ypredicted = model(x)
 loss = criterion(torch.log(ypredicted), y)
 if step % 100 == 0:
 print(step, loss)
 step += 1
 loss.backward()
 optimizer.step()

torch.save(model.state_dict(), "model_2.bin")


0


 input = module(input)


0 tensor(10.0674, device='cuda:0', grad_fn=)
100 tensor(8.4352, device='cuda:0', grad_fn=)
200 tensor(7.6662, device='cuda:0', grad_fn=)
300 tensor(7.0716, device='cuda:0', grad_fn=)
400 tensor(6.6710, device='cuda:0', grad_fn=)
500 tensor(6.4540, device='cuda:0', grad_fn=)
600 tensor(5.9974, device='cuda:0', grad_fn=)
700 tensor(5.7973, device='cuda:0', grad_fn=)
800 tensor(5.8026, device='cuda:0', grad_fn=)
10000
900 tensor(5.7118, device='cuda:0', grad_fn=)
1000 tensor(5.7471, device='cuda:0', grad_fn=)
1100 tensor(5.6865, device='cuda:0', grad_fn=)
1200 tensor(5.4205, device='cuda:0', grad_fn=)
1300 tensor(5.4954, device='cuda:0', grad_fn=)
1400 tensor(5.5415, device='cuda:0', grad_fn=)
1500 tensor(5.3322, device='cuda:0', grad_fn=)
1600 tensor(5.4665, device='cuda:0', grad_fn=)
1700 tensor(5.4710, device='cuda:0', grad_fn=)
20000
1800 tensor(5.3953, device='cuda:0', grad_fn=)
1900 tensor(5.4881, device='cuda:0', grad_fn=)
2000 tensor(5.4915, device='cuda:0', grad_fn=)
2100 tensor(

KeyboardInterrupt: ignored

In [10]:
ixs = torch.tensor(train_dataset.vocab.forward(["when"])).to(device)
out = model(ixs)
top = torch.topk(out[0], 10)
top_indices = top.indices.tolist()
top_probs = top.values.tolist()
top_words = train_dataset.vocab.lookup_tokens(top_indices)
list(zip(top_words, top_indices, top_probs))


[('the', 2, 0.15899169445037842),
 ('\\', 1, 0.10546761751174927),
 ('he', 28, 0.06849857419729233),
 ('it', 15, 0.05329886078834534),
 ('i', 26, 0.0421920120716095),
 ('they', 50, 0.03895237296819687),
 ('a', 8, 0.03352600708603859),
 ('', 0, 0.031062396243214607),
 ('we', 61, 0.02323235757648945),
 ('she', 104, 0.02003088779747486)]

In [13]:
device = "cuda"
model = SimpleBigramNeuralLanguageModel(VOCAB_SIZE, EMBED_SIZE).to(device)
model.load_state_dict(torch.load("model1.bin"))




In [14]:
def predict_word(ixs, model, top_k=5):
 out = model(ixs)
 top = torch.topk(out[0], 10)
 top_indices = top.indices.tolist()
 top_probs = top.values.tolist()
 top_words = train_dataset.vocab.lookup_tokens(top_indices)
 return list(zip(top_words, top_indices, top_probs))


def get_one_word(text, context="left"):
 # print("Getting word from:", text)
 if context == "left":
 context = -1
 else:
 context = 0
 return text.rstrip().split(" ")[context]


def inference_on_file(filename, model, lines_no=1):
 results_path = "/".join(filename.split("/")[:-1]) + "/out.tsv"
 with lzma.open(filename, "r") as fp, open(results_path, "w") as out_file:
 print("Training on", filename)
 for i, line in enumerate(fp):
 # left, right = [ get_one_word(text_part, context)
 # for context, text_part in zip(line.split('\t')[:-2], ('left', 'right'))]
 line = line.decode("utf-8")
 # print(line)
 left = get_one_word(line.split("\t")[-2])
 # print("Current word:", left)
 tensor = torch.tensor(train_dataset.vocab.forward([left])).to(device)
 results = predict_word(tensor, model, 9)
 prob_sum = sum([word[2] for word in results])
 result_line = (
 " ".join([f"{word[0]}:{word[2]}" for word in results])
 + f" :{prob_sum}\n"
 )
 # print(result_line)
 out_file.write(result_line)
 print(f"\rProgress: {(((i+1) / lines_no) * 100):.2f}%", end="")
 print()


model.eval()

for filepath, lines_no in zip(
 ("/content/dev-0/in.tsv.xz", "/content/test-A/in.tsv.xz"), (10519.0, 7414.0)
):
 inference_on_file(filepath, model, lines_no)


Training on /content/dev-0/in.tsv.xz
Progress: 0.01%Progress: 0.02%Progress: 0.03%Progress: 0.04%Progress: 0.05%Progress: 0.06%Progress: 0.07%Progress: 0.08%Progress: 0.09%Progress: 0.10%Progress: 0.10%Progress: 0.11%Progress: 0.12%Progress: 0.13%Progress: 0.14%Progress: 0.15%Progress: 0.16%Progress: 0.17%Progress: 0.18%Progress: 0.19%Progress: 0.20%Progress: 0.21%Progress: 0.22%Progress: 0.23%Progress: 0.24%Progress: 0.25%Progress: 0.26%Progress: 0.27%Progress: 0.28%Progress: 0.29%Progress: 0.29%Progress: 0.30%Progress: 0.31%Progress: 0.32%Progress: 0.33%Progress: 0.34%Progress: 0.35%Progress: 0.36%Progress: 0.37%Progress: 0.38%Progress: 0.39%Progress: 0.40%Progress: 0.41%Progress: 0.42%Progress: 0.43%Progress: 0.44%Progress: 0.45%Progress: 0.46%Progress: 0.47%Progress: 0.48%Progress: 0.48%Progress: 0.49%Progress: 0.50%Progress: 0.51%Progress: 0.52%Progress: 0.53%Progress: 0.54%Progress: 0.55%Progress: 0.56%Progress: 0.57%Pr

 input = module(input)


Progress: 100.00%
Training on /content/test-A/in.tsv.xz
Progress: 100.00%
