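# Neural language model trainer for a word-gap prediction task: it streams
# 11-gram training examples from train/in.tsv.xz, trains a model to predict
# the 11th word from the preceding 10, tracks top-k accuracy on dev-0, and
# writes word:probability predictions for the dev-0 and test-A splits.
# Relies on two project-local helper modules: `scripts` (tokenization via
# get_words_from_line, plus vocab_size) and `utils` (device, LanguageModel,
# embed_size, learning_rate).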
from torch import nn
import torch
from torch.utils.data import DataLoader
import copy
from torch.utils.data import IterableDataset
import itertools
import lzma
import regex as re
import pickle
import scripts
import string
import pdb
import utils


def divide_chunks(l, n):
    # step through l and yield successive n-element chunks
    for i in range(0, len(l), n):
        yield l[i:i + n]


# global vocabulary: maps tokens to ids, unknown tokens fall back to <unk>
with open("vocab.pickle", 'rb') as handle:
    vocab = pickle.load(handle)
vocab.set_default_index(vocab['<unk>'])


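# Group the incoming stream of token ids into consecutive, non-overlapping
# 11-element windows: the first 10 ids are the left context and the 11th is
# the prediction target.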
def look_ahead_iterator(gen):
    seq = []
    for item in gen:
        seq.append(item)
        if len(seq) == 11:
            yield seq
            seq = []


def get_word_lines_from_file(file_name):
    counter = 0
    with lzma.open(file_name, 'r') as fh:
        for line in fh:
            counter += 1
            # if counter == 100000:
            #     break
            line = line.decode("utf-8")
            yield scripts.get_words_from_line(line)


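# Streaming dataset of 11-gram windows: tokens are mapped to vocabulary ids
# on the fly, so the xz-compressed training file never has to fit in memory.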
class Grams_10(IterableDataset):

    def load_vocab(self):
        with open("vocab.pickle", 'rb') as handle:
            vocab = pickle.load(handle)
        return vocab

    def __init__(self, text_file, vocab):
        self.vocab = vocab
        self.vocab.set_default_index(self.vocab['<unk>'])
        self.text_file = text_file

    def __iter__(self):
        return look_ahead_iterator(
            (self.vocab[t] for t in itertools.chain.from_iterable(get_word_lines_from_file(self.text_file))))


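# With an IterableDataset that yields plain lists of 11 ids, the default
# DataLoader collate function batches position-wise: each batch is a list of
# 11 tensors of shape [batch_size], which the training loop below unpacks as
# 10 context tensors plus 1 target tensor.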
vocab_size = scripts.vocab_size

train_dataset = Grams_10('train/in.tsv.xz', vocab)
BATCH_SIZE = 2048

train_data = DataLoader(train_dataset, batch_size=BATCH_SIZE)


PREFIX_TRAIN = 'train'
PREFIX_VALID = 'dev-0'
BATCHES = []

# In-memory training reader, kept for reference; the streaming Grams_10
# dataset above is used instead.
# def read_train_file(folder_prefix, vocab):
#     dataset_x = []
#     dataset_y = []
#     counter_lines = 0
#     seq_len = 10
#     with lzma.open(f'{folder_prefix}/in.tsv.xz', 'r') as train, open(f'{folder_prefix}/expected.tsv', 'r') as expected:
#         for t_line, e_line in zip(train, expected):
#             t_line = t_line.decode("utf-8")
#             t_line = t_line.rstrip()
#             e_line = e_line.rstrip()
#             t_line = t_line.translate(str.maketrans('', '', string.punctuation))
#             t_line_splitted_by_tab = t_line.split('\t')
#             # t_line_cleared = t_line_splitted_by_tab[-2] + ' ' + e_line + ' ' + t_line_splitted_by_tab[-1]
#
#             whole_line = t_line_splitted_by_tab[-2] + ' ' + e_line + ' ' + t_line_splitted_by_tab[-1]
#             whole_line_splitted = list(scripts.get_words_from_line(whole_line))
#             whole_lines_splitted = divide_chunks(whole_line_splitted, 11)
#
#             for chunk_line in whole_line_splitted:
#                 left_context_splitted = chunk_line[0:10]
#
#                 seq_x = []
#                 for i in range(seq_len):
#                     if len(left_context_splitted) < i + 1:
#                         seq_x.insert(0, '<empty>')
#                     else:
#                         seq_x.insert(0, left_context_splitted[-1 - i])
#
#                 left_vocabed = [vocab[t] for t in seq_x]
#
#                 dataset_x.append(left_vocabed)
#                 dataset_y.append([vocab[chunk_line[10]]])
#
#                 counter_lines += 1
#                 # if counter_lines > 20000:
#                 #     break
#     return dataset_x, dataset_y


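# Build dev examples: for every line, take the last 10 tokens of the left
# context (padded on the left with '<empty>' when the context is shorter),
# map them to vocabulary ids, and use the expected word as the target.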
def read_dev_file(folder_prefix, vocab):
    dataset_x = []
    dataset_y = []
    counter_lines = 0
    seq_len = 10
    with lzma.open(f'{folder_prefix}/in.tsv.xz', 'r') as train, open(f'{folder_prefix}/expected.tsv', 'r') as expected:
        for t_line, e_line in zip(train, expected):
            t_line = t_line.decode("utf-8")
            t_line = t_line.rstrip()
            e_line = e_line.rstrip()
            t_line = t_line.translate(str.maketrans('', '', string.punctuation))
            t_line_splitted_by_tab = t_line.split('\t')
            # t_line_cleared = t_line_splitted_by_tab[-2] + ' ' + e_line + ' ' + t_line_splitted_by_tab[-1]

            left_context = t_line_splitted_by_tab[-2]
            left_context_splitted = list(scripts.get_words_from_line(left_context))

            seq_x = []
            for i in range(seq_len):
                if len(left_context_splitted) < i + 1:
                    seq_x.insert(0, '<empty>')
                else:
                    seq_x.insert(0, left_context_splitted[-1 - i])

            left_vocabed = [vocab[t] for t in seq_x]

            dataset_x.append(left_vocabed)
            dataset_y.append([vocab[e_line]])

            counter_lines += 1
            # if counter_lines > 20000:
            #     break
    return dataset_x, dataset_y


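# Same feature construction for the blind test set, which has no
# expected.tsv, so only the context features are returned.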
def read_test_file(folder_prefix, vocab):
    dataset_x = []
    counter_lines = 0
    seq_len = 10
    with lzma.open(f'{folder_prefix}/in.tsv.xz', 'r') as train:
        for t_line in train:
            t_line = t_line.decode("utf-8")
            t_line = t_line.rstrip()
            t_line = t_line.translate(str.maketrans('', '', string.punctuation))
            t_line_splitted_by_tab = t_line.split('\t')

            left_context = t_line_splitted_by_tab[-2]
            left_context_splitted = list(scripts.get_words_from_line(left_context))

            seq_x = []
            for i in range(seq_len):
                if len(left_context_splitted) < i + 1:
                    seq_x.insert(0, '<empty>')
                else:
                    seq_x.insert(0, left_context_splitted[-1 - i])

            left_vocabed = [vocab[t] for t in seq_x]

            dataset_x.append(left_vocabed)

            counter_lines += 1
            # if counter_lines > 20000:
            #     break
    return dataset_x


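# Materialize the dev and test features; only the training data is streamed.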
# train_set_x, train_set_y = read_file(PREFIX_TRAIN, vocab)
dev_set_x, dev_set_y = read_dev_file(PREFIX_VALID, vocab)

test_set_x = read_test_file('test-A', vocab)

# train_data_x = DataLoader(train_set_x, batch_size=4048)
# train_data_y = DataLoader(train_set_y, batch_size=4048)

dev_data_x = DataLoader(dev_set_x, batch_size=1)
dev_data_y = DataLoader(dev_set_y, batch_size=1)

test_set_x = DataLoader(test_set_x, batch_size=1)
# pdb.set_trace()

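# Model, optimizer and loss. LanguageModel, embed_size and learning_rate come
# from the project-local utils module; the model is assumed to output a
# normalized probability distribution over the vocabulary, since NLLLoss is
# applied to its log below.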
device = utils.device

model = utils.LanguageModel(scripts.vocab_size, utils.embed_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=utils.learning_rate)
criterion = torch.nn.NLLLoss()
model.train()

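# Training loop: each batch is a list of 11 tensors of shape [batch_size];
# the first 10 are the context words and the last one is the target word.
# Progress is printed every 10000 steps.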
step = 0
last_best_acc = -1
epochs = 3
for epoch in range(epochs):
    model.train()
    for batch in train_data:
        x = batch[:10]
        y = [batch[10]]

        x = [i.to(device) for i in x]
        y = y[0].to(device)
        optimizer.zero_grad()
        ypredicted = model(x)
        # pdb.set_trace()
        loss = criterion(torch.log(ypredicted), y)
        if step % 10000 == 0:
            print('Step: ', step, loss)
            # torch.save(model.state_dict(), f'model1_{step}.bin')
        step += 1
        loss.backward()
        optimizer.step()

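    # Per-epoch evaluation on dev-0: count how often the gold word appears in
    # the model's top-64 predictions and keep a copy of the best model so far.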
    model.eval()
    top_k_true = 0
    for d_x, d_y in zip(dev_data_x, dev_data_y):
        # pdb.set_trace()
        d_x = [i.to(device) for i in d_x]
        # d_y = d_y.to(device)

        ypredicted = model(d_x)

        top = torch.topk(ypredicted[0], 64)
        top_indices = top.indices.tolist()
        if d_y[0].item() in top_indices:
            top_k_true += 1
    my_acc = top_k_true / len(dev_data_y)
    print('My_accuracy: ', my_acc, ", epoch: ", epoch)
    if my_acc > last_best_acc:
        print('NEW BEST -- My_accuracy: ', my_acc, ", epoch: ", epoch)
        last_best_acc = my_acc
        best_model = copy.deepcopy(model)
        torch.save(model.state_dict(), 'model_last_best_.bin')
    if epoch % 15 == 0:
        print('Epoch: ', epoch, step, loss)
        # torch.save(model.state_dict(), f'model_epoch_{epoch}_.bin')


# inference
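# For every dev example take the model's top-10 predictions, drop <unk>, and
# write them as space-separated "word:probability" pairs; the remaining
# probability mass goes to the empty word (":prob") at the end of the line.
# Lines with no usable predictions fall back to a fixed distribution.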
print('inference')
inference_result = []
for d_x, d_y in zip(dev_data_x, dev_data_y):
    # pdb.set_trace()
    d_x = [i.to(device) for i in d_x]
    # d_y = d_y.to(device)

    ypredicted = model(d_x)

    top = torch.topk(ypredicted[0], 10)
    top_indices = top.indices.tolist()
    top_probs = top.values.tolist()
    top_words = vocab.lookup_tokens(top_indices)

    string_to_print = ''

    sum_probs = 0
    for w, p in zip(top_words, top_probs):
        # print(top_words)
        if '<unk>' in w:
            continue
        string_to_print += f"{w}:{p} "
        sum_probs += p

    if string_to_print == '':
        inference_result.append("the:0.2 a:0.3 :0.5")
        continue
    unknown_prob = 1 - sum_probs
    string_to_print += f":{unknown_prob}"

    inference_result.append(string_to_print)

with open('dev-0/out.tsv', 'w') as f:
    for line in inference_result:
        f.write(line + '\n')


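# Repeat the same prediction procedure for the blind test-A split.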
print('inference test')
inference_result = []
for d_x in test_set_x:
    # pdb.set_trace()
    d_x = [i.to(device) for i in d_x]

    ypredicted = model(d_x)

    top = torch.topk(ypredicted[0], 64)
    top_indices = top.indices.tolist()
    top_probs = top.values.tolist()
    top_words = vocab.lookup_tokens(top_indices)

    string_to_print = ''

    sum_probs = 0
    for w, p in zip(top_words, top_probs):
        # print(top_words)
        if '<unk>' in w:
            continue
        string_to_print += f"{w}:{p} "
        sum_probs += p

    if string_to_print == '':
        inference_result.append("the:0.2 a:0.3 :0.5")
        continue
    unknown_prob = 1 - sum_probs
    string_to_print += f":{unknown_prob}"

    inference_result.append(string_to_print)

with open('test-A/out.tsv', 'w') as f:
    for line in inference_result:
        f.write(line + '\n')

print('All done')