Bigram neural finish.

Jan Nowak 2022-05-07 14:53:24 +02:00
parent fd03c9369f
commit 9c381a9eea
3 changed files with 17979 additions and 17942 deletions

File diff suppressed because it is too large

run.py (55 changed lines)

@@ -9,6 +9,8 @@ from torch.utils.data import IterableDataset
 import itertools
 from torch.utils.data import DataLoader
 import numpy as np
+from nltk.tokenize import RegexpTokenizer
+from nltk import trigrams
 
 # def get_words_from_line(file_path):
@@ -20,6 +22,13 @@ import numpy as np
 #         if index == 10000:
 #             break
+tokenizer = RegexpTokenizer(r"\w+")
+
+def read_file_6(file):
+    for line in file:
+        text = line.split("\t")
+        yield re.sub(r"[^\w\d'\s]+", '', re.sub(' +', ' ', text[6].replace("\\n", " ").replace("\n", "").lower()))
 
 def get_words_from_line(line):
     line = line.rstrip()
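For reference, the new read_file_6 helper pulls column 7 (index 6) out of each TSV record and normalizes it. A minimal standalone sketch of that cleanup, using a made-up sample record (the sample string is an assumption, not data from the repo):

import re

# Hypothetical TSV record; column 7 (index 6) holds the text to clean.
sample = "id\t1\t2\t3\t4\t5\tSaid that\\nhe would NOT,  come\tright context\n"

def read_file_6(file):
    for line in file:
        text = line.split("\t")
        # join escaped/literal newlines, lowercase, squeeze spaces, drop punctuation
        yield re.sub(r"[^\w\d'\s]+", '',
                     re.sub(' +', ' ',
                            text[6].replace("\\n", " ").replace("\n", "").lower()))

print(next(read_file_6([sample])))  # said that he would not come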
@@ -34,11 +43,11 @@ def get_words_lines_from_file(file_path):
     for index, line in enumerate(file):
         text = line.split("\t")
         yield get_words_from_line(re.sub(r"[^\w\d'\s]+", '', re.sub(' +', ' ', ' '.join([text[6], text[7]]).replace("\\n", " ").replace("\n", "").lower())))
-        if index == 50000:
-            break
+        # if index == 1000:
+        #     break
 
-vocab_size = 20000
+vocab_size = 30000
 vocab = build_vocab_from_iterator(
     get_words_lines_from_file('train/in.tsv.xz'),
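For context on the vocab lines above, a self-contained sketch of how torchtext's build_vocab_from_iterator behaves; the toy corpus and the specials/max_tokens arguments here are illustrative assumptions, not taken from this diff:

from torchtext.vocab import build_vocab_from_iterator

toy_corpus = [["the", "cat", "sat"], ["the", "dog", "sat"]]  # iterator of token lists

vocab = build_vocab_from_iterator(
    toy_corpus,
    max_tokens=5,        # keep at most 5 entries, most frequent first
    specials=['<unk>'])  # reserve an out-of-vocabulary token at index 0
vocab.set_default_index(vocab['<unk>'])

print(vocab.forward(['the', 'aardvark']))  # e.g. [2, 0]; 'aardvark' falls back to <unk>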
@@ -88,7 +97,7 @@ class Bigrams(IterableDataset):
 def train():
-    batch_size = 22000
+    batch_size = 15000
     train_dataset = Bigrams('train/in.tsv.xz', vocab_size)
@@ -117,23 +126,32 @@ def train():
         loss.backward()
         # Update Weights
         optimizer.step()
+        print(step)
     torch.save(model.state_dict(), 'model1.bin')
 
-def predict():
+def predict(word):
     device = 'cuda'
     model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)
     model.load_state_dict(torch.load('model1.bin'))
     model.eval()
-    ixs = torch.tensor(vocab.forward(['for'])).to(device)
+    ixs = torch.tensor(vocab.forward([word])).to(device)
     out = model(ixs)
-    top = torch.topk(out[0], 10)
+    top = torch.topk(out[0], 8)
     top_indices = top.indices.tolist()
     top_probs = top.values.tolist()
     top_words = vocab.lookup_tokens(top_indices)
-    print(list(zip(top_words, top_indices, top_probs)))
+    str_predictions = ""
+    lht = 1.0
+    for pred_word in list(zip(top_words, top_indices, top_probs)):
+        if lht - pred_word[2] >= 0:
+            str_predictions += f"{pred_word[0]}:{pred_word[2]} "
+            lht -= pred_word[2]
+    if lht != 1.0:
+        str_predictions += f":{lht}"
+    return str_predictions
 
 def similar():
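The rewritten predict() now returns the challenge's expected "word:prob ... :rest" line instead of printing tuples: it spends at most 1.0 of probability mass on the top-k words and assigns whatever is left to the empty token. A standalone sketch of that bookkeeping with toy probabilities (the triples below are made up):

# Toy stand-ins for the (word, index, prob) triples zipped in predict().
top = [("the", 1, 0.5), ("a", 2, 0.25), ("to", 3, 0.125)]

str_predictions = ""
lht = 1.0                      # probability mass not yet assigned
for word, _, prob in top:
    if lht - prob >= 0:        # only spend mass that is still available
        str_predictions += f"{word}:{prob} "
        lht -= prob
if lht != 1.0:                 # hand the remainder to the empty token
    str_predictions += f":{lht}"

print(str_predictions)  # the:0.5 a:0.25 to:0.125 :0.125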
@@ -158,6 +176,25 @@ def similar():
     print(list(zip(top_words, top_indices, top_probs)))
 
+def generate_outputs(input_file, output_file):
+    with open(output_file, 'w') as outputf:
+        with lzma.open(input_file, mode='rt') as file:
+            for index, text in enumerate(read_file_6(file)):
+                tokens = tokenizer.tokenize(text)
+                if len(tokens) < 4:
+                    prediction = 'the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1'
+                else:
+                    prediction = predict(tokens[-1])
+                outputf.write(prediction + '\n')
+
 if __name__ == "__main__":
     # train()
-    predict()
+    # predict()
+    # generate_outputs("dev-0/in.tsv.xz", "dev-0/out.tsv")
+    generate_outputs("test-A/in.tsv.xz", "test-A/out.tsv")
+    # count_words = 0
+    # for i in get_words_lines_from_file('train/in.tsv.xz'):
+    #     for j in i:
+    #         count_words += 1
+    # print(count_words)
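generate_outputs writes one such distribution per input line, falling back to a fixed distribution when the cleaned context has fewer than four tokens (note that predict() reloads the model from disk on every call, which works but is slow). A toy run of the same writing logic, with a hypothetical stub in place of the trained model:

# Hypothetical stub standing in for the trained model's predict().
def predict_stub(word):
    return "the:0.5 a:0.25 to:0.125 :0.125"

contexts = ["said that he would not come", "go on"]  # cleaned lines, as from read_file_6

with open("out.tsv", "w") as outputf:
    for text in contexts:
        tokens = text.split()
        if len(tokens) < 4:  # too little context: emit the fixed fallback
            prediction = "the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1"
        else:
            prediction = predict_stub(tokens[-1])  # condition on the last word only
        outputf.write(prediction + "\n")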

File diff suppressed because it is too large