Bigram neural finish.
This commit is contained in:
parent
fd03c9369f
commit
9c381a9eea
21038
dev-0/out.tsv
21038
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
55
run.py
55
run.py
@ -9,6 +9,8 @@ from torch.utils.data import IterableDataset
|
|||||||
import itertools
|
import itertools
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from nltk.tokenize import RegexpTokenizer
|
||||||
|
from nltk import trigrams
|
||||||
|
|
||||||
|
|
||||||
# def get_words_from_line(file_path):
|
# def get_words_from_line(file_path):
|
||||||
@ -20,6 +22,13 @@ import numpy as np
|
|||||||
# if index == 10000:
|
# if index == 10000:
|
||||||
# break
|
# break
|
||||||
|
|
||||||
|
tokenizer = RegexpTokenizer(r"\w+")
|
||||||
|
|
||||||
|
def read_file_6(file):
|
||||||
|
for line in file:
|
||||||
|
text = line.split("\t")
|
||||||
|
yield re.sub(r"[^\w\d'\s]+", '', re.sub(' +', ' ', text[6].replace("\\n", " ").replace("\n", "").lower()))
|
||||||
|
|
||||||
|
|
||||||
def get_words_from_line(line):
|
def get_words_from_line(line):
|
||||||
line = line.rstrip()
|
line = line.rstrip()
|
||||||
@ -34,11 +43,11 @@ def get_words_lines_from_file(file_path):
|
|||||||
for index, line in enumerate(file):
|
for index, line in enumerate(file):
|
||||||
text = line.split("\t")
|
text = line.split("\t")
|
||||||
yield get_words_from_line(re.sub(r"[^\w\d'\s]+", '', re.sub(' +', ' ', ' '.join([text[6], text[7]]).replace("\\n", " ").replace("\n", "").lower())))
|
yield get_words_from_line(re.sub(r"[^\w\d'\s]+", '', re.sub(' +', ' ', ' '.join([text[6], text[7]]).replace("\\n", " ").replace("\n", "").lower())))
|
||||||
if index == 50000:
|
# if index == 1000:
|
||||||
break
|
# break
|
||||||
|
|
||||||
|
|
||||||
vocab_size = 20000
|
vocab_size = 30000
|
||||||
|
|
||||||
vocab = build_vocab_from_iterator(
|
vocab = build_vocab_from_iterator(
|
||||||
get_words_lines_from_file('train/in.tsv.xz'),
|
get_words_lines_from_file('train/in.tsv.xz'),
|
||||||
@ -88,7 +97,7 @@ class Bigrams(IterableDataset):
|
|||||||
|
|
||||||
|
|
||||||
def train():
|
def train():
|
||||||
batch_size = 22000
|
batch_size = 15000
|
||||||
|
|
||||||
train_dataset = Bigrams('train/in.tsv.xz', vocab_size)
|
train_dataset = Bigrams('train/in.tsv.xz', vocab_size)
|
||||||
|
|
||||||
@ -117,23 +126,32 @@ def train():
|
|||||||
loss.backward()
|
loss.backward()
|
||||||
# Update Weights
|
# Update Weights
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
print(step)
|
||||||
torch.save(model.state_dict(), 'model1.bin')
|
torch.save(model.state_dict(), 'model1.bin')
|
||||||
|
|
||||||
|
|
||||||
def predict():
|
def predict(word):
|
||||||
device = 'cuda'
|
device = 'cuda'
|
||||||
model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)
|
model = SimpleBigramNeuralLanguageModel(vocab_size, embed_size).to(device)
|
||||||
model.load_state_dict(torch.load('model1.bin'))
|
model.load_state_dict(torch.load('model1.bin'))
|
||||||
model.eval()
|
model.eval()
|
||||||
|
|
||||||
ixs = torch.tensor(vocab.forward(['for'])).to(device)
|
ixs = torch.tensor(vocab.forward([word])).to(device)
|
||||||
|
|
||||||
out = model(ixs)
|
out = model(ixs)
|
||||||
top = torch.topk(out[0], 10)
|
top = torch.topk(out[0], 8)
|
||||||
top_indices = top.indices.tolist()
|
top_indices = top.indices.tolist()
|
||||||
top_probs = top.values.tolist()
|
top_probs = top.values.tolist()
|
||||||
top_words = vocab.lookup_tokens(top_indices)
|
top_words = vocab.lookup_tokens(top_indices)
|
||||||
print(list(zip(top_words, top_indices, top_probs)))
|
str_predictions = ""
|
||||||
|
lht = 1.0
|
||||||
|
for pred_word in list(zip(top_words, top_indices, top_probs)):
|
||||||
|
if lht - pred_word[2] >= 0:
|
||||||
|
str_predictions += f"{pred_word[0]}:{pred_word[2]} "
|
||||||
|
lht -= pred_word[2]
|
||||||
|
if lht != 1.0:
|
||||||
|
str_predictions += f":{lht}"
|
||||||
|
return str_predictions
|
||||||
|
|
||||||
|
|
||||||
def similar():
|
def similar():
|
||||||
@ -158,6 +176,25 @@ def similar():
|
|||||||
print(list(zip(top_words, top_indices, top_probs)))
|
print(list(zip(top_words, top_indices, top_probs)))
|
||||||
|
|
||||||
|
|
||||||
|
def generate_outputs(input_file, output_file):
|
||||||
|
with open(output_file, 'w') as outputf:
|
||||||
|
with lzma.open(input_file, mode='rt') as file:
|
||||||
|
for index, text in enumerate(read_file_6(file)):
|
||||||
|
tokens = tokenizer.tokenize(text)
|
||||||
|
if len(tokens) < 4:
|
||||||
|
prediction = 'the:0.2 be:0.2 to:0.2 of:0.1 and:0.1 a:0.1 :0.1'
|
||||||
|
else:
|
||||||
|
prediction = predict(tokens[-1])
|
||||||
|
outputf.write(prediction + '\n')
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
# train()
|
# train()
|
||||||
predict()
|
# predict()
|
||||||
|
# generate_outputs("dev-0/in.tsv.xz", "dev-0/out.tsv")
|
||||||
|
generate_outputs("test-A/in.tsv.xz", "test-A/out.tsv")
|
||||||
|
# count_words = 0
|
||||||
|
# for i in get_words_lines_from_file('train/in.tsv.xz'):
|
||||||
|
# for j in i:
|
||||||
|
# count_words += 1
|
||||||
|
# print(count_words)
|
||||||
|
14828
test-A/out.tsv
14828
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user