from torch import nn import torch from torch.utils.data import DataLoader import copy from torch.utils.data import IterableDataset import itertools import lzma import regex as re import pickle import scripts import string import pdb import utils def divide_chunks(l, n): # looping till length l for i in range(0, len(l), n): yield l[i:i + n] with open("vocab.pickle", 'rb') as handle: vocab = pickle.load( handle) vocab.set_default_index(vocab['']) def look_ahead_iterator(gen): seq = [] counter = 0 for item in gen: seq.append(item) if counter % 11 == 0 and counter !=0: if len(seq) == 11: yield seq seq = [] counter+=1 def get_word_lines_from_file(file_name): counter=0 with lzma.open(file_name, 'r') as fh: for line in fh: counter+=1 # if counter == 100000: # break line = line.decode("utf-8") yield scripts.get_words_from_line(line) class Grams_10(IterableDataset): def load_vocab(self): with open("vocab.pickle", 'rb') as handle: vocab = pickle.load( handle) return vocab def __init__(self, text_file, vocab): self.vocab = vocab self.vocab.set_default_index(self.vocab['']) self.text_file = text_file def __iter__(self): return look_ahead_iterator( (self.vocab[t] for t in itertools.chain.from_iterable(get_word_lines_from_file(self.text_file)))) vocab_size = scripts.vocab_size train_dataset = Grams_10('train/in.tsv.xz', vocab) BATCH_SIZE = 2048 train_data = DataLoader(train_dataset, batch_size=BATCH_SIZE) PREFIX_TRAIN = 'train' PREFIX_VALID = 'dev-0' BATCHES = [] # def read_train_file(folder_prefix, vocab): # dataset_x = [] # dataset_y = [] # counter_lines = 0 # seq_len = 10 # with lzma.open(f'{folder_prefix}/in.tsv.xz', 'r') as train, open(f'{folder_prefix}/expected.tsv', 'r') as expected: # for t_line, e_line in zip(train, expected): # t_line = t_line.decode("utf-8") # t_line = t_line.rstrip() # e_line = e_line.rstrip() # t_line = t_line.translate(str.maketrans('', '', string.punctuation)) # t_line_splitted_by_tab = t_line.split('\t') # # t_line_cleared = t_line_splitted_by_tab[-2] + ' ' + e_line + ' ' + t_line_splitted_by_tab[-1] # whole_line = t_line_splitted_by_tab[-2] + ' ' + e_line + ' ' + t_line_splitted_by_tab[-1] # whole_line_splitted = list(scripts.get_words_from_line(whole_line)) # whole_lines_splitted = divide_chunks(whole_line_splitted, 11) # for chunk_line in whole_line_splitted: # left_context_splitted = chunk_line[0:10] # seq_x = [] # for i in range(seq_len): # index = -1 - i # if len(left_context_splitted) < i + 1: # seq_x.insert(0, '') # else: # seq_x.insert(0, left_context_splitted[-1 -i]) # left_vocabed = [vocab[t] for t in seq_x] # dataset_x.append(left_vocabed ) # dataset_y.append([vocab[chunk_line[10]]]) # counter_lines+=1 # # if counter_lines > 20000: # # break # return dataset_x, dataset_y def read_dev_file(folder_prefix, vocab): dataset_x = [] dataset_y = [] counter_lines = 0 seq_len = 10 with lzma.open(f'{folder_prefix}/in.tsv.xz', 'r') as train, open(f'{folder_prefix}/expected.tsv', 'r') as expected: for t_line, e_line in zip(train, expected): t_line = t_line.decode("utf-8") t_line = t_line.rstrip() e_line = e_line.rstrip() t_line = t_line.translate(str.maketrans('', '', string.punctuation)) t_line_splitted_by_tab = t_line.split('\t') # t_line_cleared = t_line_splitted_by_tab[-2] + ' ' + e_line + ' ' + t_line_splitted_by_tab[-1] left_context = t_line_splitted_by_tab[-2] left_context_splitted = list(scripts.get_words_from_line(left_context)) seq_x = [] for i in range(seq_len): index = -1 - i if len(left_context_splitted) < i + 1: seq_x.insert(0, '') else: seq_x.insert(0, left_context_splitted[-1 -i]) left_vocabed = [vocab[t] for t in seq_x] dataset_x.append(left_vocabed ) dataset_y.append([vocab[e_line]]) counter_lines+=1 # if counter_lines > 20000: # break return dataset_x, dataset_y def read_test_file(folder_prefix, vocab): dataset_x = [] dataset_y = [] counter_lines = 0 seq_len = 10 with lzma.open(f'{folder_prefix}/in.tsv.xz', 'r') as train: for t_line in train: t_line = t_line.decode("utf-8") t_line = t_line.rstrip() t_line = t_line.translate(str.maketrans('', '', string.punctuation)) t_line_splitted_by_tab = t_line.split('\t') # t_line_cleared = t_line_splitted_by_tab[-2] + ' ' + e_line + ' ' + t_line_splitted_by_tab[-1] left_context = t_line_splitted_by_tab[-2] left_context_splitted = list(scripts.get_words_from_line(left_context)) seq_x = [] for i in range(seq_len): index = -1 - i if len(left_context_splitted) < i + 1: seq_x.insert(0, '') else: seq_x.insert(0, left_context_splitted[-1 -i]) left_vocabed = [vocab[t] for t in seq_x] dataset_x.append(left_vocabed ) counter_lines+=1 # if counter_lines > 20000: # break return dataset_x # train_set_x, train_set_y = read_file(PREFIX_TRAIN, vocab) dev_set_x, dev_set_y = read_dev_file(PREFIX_VALID, vocab) test_set_x = read_test_file('test-A', vocab) # train_data_x = DataLoader(train_set_x, batch_size=4048) # train_data_y = DataLoader(train_set_y, batch_size=4048) # train_data_x = DataLoader(train_set_x, batch_size=4048) # train_data_y = DataLoader(train_set_y, batch_size=4048) dev_data_x = DataLoader(dev_set_x, batch_size=1) dev_data_y = DataLoader(dev_set_y, batch_size=1) test_set_x = DataLoader(test_set_x, batch_size=1) # pdb.set_trace() device = utils.device model = utils.LanguageModel(scripts.vocab_size, utils.embed_size).to(device) optimizer = torch.optim.Adam(model.parameters(), lr=utils.learning_rate) criterion = torch.nn.NLLLoss() model.train() step = 0 last_best_acc = -1 epochs = 3 for epoch in range(epochs): model.train() for batch in train_data: x = batch[:10] y = [batch[10]] x = [i.to(device) for i in x] y = y[0].to(device) optimizer.zero_grad() ypredicted = model(x) # pdb.set_trace() loss = criterion(torch.log(ypredicted), y) if step % 10000 == 0: print('Step: ', step, loss) # torch.save(model.state_dict(), f'model1_{step}.bin') step += 1 loss.backward() optimizer.step() # evaluation model.eval() y_predeicted = [] top_50_true = 0 for d_x, d_y in zip(dev_data_x, dev_data_y): # pdb.set_trace() d_x = [i.to(device) for i in d_x] # d_y = d_y.to(device) optimizer.zero_grad() ypredicted = model(d_x) top = torch.topk(ypredicted[0], 64) top_indices = top.indices.tolist() if d_y[0] in top_indices: top_50_true+=1 my_acc = top_50_true/len(dev_data_y) print('My_accuracy: ', my_acc, ", epoch: ", epoch) if my_acc > last_best_acc: print('NEW BEST -- My_accuracy: ', my_acc, ", epoch: ", epoch) last_best_acc = my_acc best_model = copy.deepcopy(model) torch.save(model.state_dict(), f'model_last_best_.bin') if epoch % 15 == 0: print('Epoch: ', epoch, step, loss) # torch.save(model.state_dict(), f'model_epoch_{epoch}_.bin') # inference print('inference') inference_result = [] for d_x, d_y in zip(dev_data_x, dev_data_y): # pdb.set_trace() d_x = [i.to(device) for i in d_x] # d_y = d_y.to(device) optimizer.zero_grad() ypredicted = model(d_x) top = torch.topk(ypredicted[0], 10) top_indices = top.indices.tolist() top_probs = top.values.tolist() top_words = vocab.lookup_tokens(top_indices) string_to_print = '' sum_probs = 0 for w, p in zip(top_words, top_probs): # print(top_words) if '' in w: continue string_to_print += f"{w}:{p} " sum_probs += p if string_to_print == '': inference_result.append("the:0.2 a:0.3 :0.5") continue unknow_prob = 1 - sum_probs string_to_print += f":{unknow_prob}" inference_result.append(string_to_print) with open('dev-0/out.tsv', 'w') as f: for line in inference_result: f.write(line+'\n') print('inference test') inference_result = [] for d_x in test_set_x: # pdb.set_trace() d_x = [i.to(device) for i in d_x] # d_y = d_y.to(device) optimizer.zero_grad() ypredicted = model(d_x) top = torch.topk(ypredicted[0], 64) top_indices = top.indices.tolist() top_probs = top.values.tolist() top_words = vocab.lookup_tokens(top_indices) string_to_print = '' sum_probs = 0 for w, p in zip(top_words, top_probs): # print(top_words) if '' in w: continue string_to_print += f"{w}:{p} " sum_probs += p if string_to_print == '': inference_result.append("the:0.2 a:0.3 :0.5") continue unknow_prob = 1 - sum_probs string_to_print += f":{unknow_prob}" inference_result.append(string_to_print) with open('test-A/out.tsv', 'w') as f: for line in inference_result: f.write(line+'\n') print('All done')