Commit: c5c85120f9
Parent: 259235ac26
Commit message: done
dev-0/out.tsv     10519 lines    Normal file    (diff suppressed because it is too large)
test-A/out.tsv     7414 lines    Normal file    (diff suppressed because it is too large)
(diffs of the remaining large files in this commit are likewise suppressed)
x_create_vocab.py    29 lines    Normal file
@@ -0,0 +1,29 @@
from itertools import islice
import regex as re
import sys
from torchtext.vocab import build_vocab_from_iterator
import lzma
import utils
import torch


def get_word_lines_from_file(file_name):
    # Yield the tokenized words of each line of an xz-compressed text file.
    counter = 0
    with lzma.open(file_name, 'r') as fh:
        for line in fh:
            counter += 1
            # if counter == 4000:
            #     break
            line = line.decode("utf-8")
            yield utils.get_words_from_line(line)


vocab_size = utils.vocab_size

# Build a vocabulary of at most vocab_size tokens from the training corpus,
# reserving <unk> for out-of-vocabulary words and <empty> for left-context padding.
vocab = build_vocab_from_iterator(
    get_word_lines_from_file('train/in.tsv.xz'),
    max_tokens=vocab_size,
    specials=['<unk>', '<empty>'])


import pickle
with open("vocab.pickle", 'wb') as handle:
    pickle.dump(vocab, handle)
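A quick way to sanity-check the pickled vocabulary built above (a minimal sketch; 'the' is only an illustrative query word, and the indices of the specials follow from specials=['<unk>', '<empty>']):

import pickle

with open("vocab.pickle", 'rb') as handle:
    vocab = pickle.load(handle)

# Map out-of-vocabulary words to <unk>, as the training script does later.
vocab.set_default_index(vocab['<unk>'])

print(len(vocab))                      # number of kept tokens, at most utils.vocab_size
print(vocab['the'])                    # index of an example word (falls back to <unk> if absent)
print(vocab.lookup_tokens([0, 1]))     # the specials: ['<unk>', '<empty>']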
x_train.py    348 lines    Normal file
@@ -0,0 +1,348 @@
from torch import nn
import torch
from torch.utils.data import DataLoader
import copy
from torch.utils.data import IterableDataset
import itertools
import lzma
import regex as re
import pickle
import scripts
import string
import pdb
import utils


def divide_chunks(l, n):
    # yield successive n-sized chunks of l
    for i in range(0, len(l), n):
        yield l[i:i + n]


# Load the vocabulary built by x_create_vocab.py and map unknown words to <unk>.
with open("vocab.pickle", 'rb') as handle:
    vocab = pickle.load(handle)
vocab.set_default_index(vocab['<unk>'])

def look_ahead_iterator(gen):
    # Group the flat token stream into non-overlapping 11-grams:
    # ten words of left context followed by the target word.
    seq = []
    counter = 0
    for item in gen:
        seq.append(item)
        if counter % 11 == 0 and counter != 0:
            if len(seq) == 11:
                yield seq
            seq = []
        counter += 1

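Because the counter is checked before it is incremented, the grouping above skips the first twelve tokens of the stream and then emits every following run of eleven consecutive tokens as one example (ten context words plus a target). A toy check of that behaviour, with integers standing in for vocabulary indices:

groups = list(look_ahead_iterator(iter(range(40))))
# groups == [list(range(12, 23)), list(range(23, 34))]
# i.e. two 11-grams; tokens 0-11 and the incomplete tail are never used.
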
def get_word_lines_from_file(file_name):
    counter = 0
    with lzma.open(file_name, 'r') as fh:
        for line in fh:
            counter += 1
            # if counter == 100000:
            #     break
            line = line.decode("utf-8")
            yield scripts.get_words_from_line(line)

class Grams_10(IterableDataset):
    def load_vocab(self):
        with open("vocab.pickle", 'rb') as handle:
            vocab = pickle.load(handle)
        return vocab

    def __init__(self, text_file, vocab):
        self.vocab = vocab
        self.vocab.set_default_index(self.vocab['<unk>'])
        self.text_file = text_file

    def __iter__(self):
        # Stream the training file as one flat sequence of vocabulary indices
        # and let look_ahead_iterator cut it into 11-grams.
        return look_ahead_iterator(
            (self.vocab[t] for t in itertools.chain.from_iterable(get_word_lines_from_file(self.text_file))))

vocab_size = scripts.vocab_size

train_dataset = Grams_10('train/in.tsv.xz', vocab)
BATCH_SIZE = 2048

train_data = DataLoader(train_dataset, batch_size=BATCH_SIZE)

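Each item yielded by Grams_10 is a plain Python list of eleven vocabulary indices, so DataLoader's default collate function stacks the items position-wise: every batch comes out as a list of eleven tensors of shape [BATCH_SIZE], which is what the training loop below relies on when it splits off batch[:10] and batch[10]. A toy illustration with hand-made 11-grams (not the real data):

toy_ngrams = [list(range(i, i + 11)) for i in range(4)]   # four fake 11-grams
toy_loader = DataLoader(toy_ngrams, batch_size=2)
first_batch = next(iter(toy_loader))
# first_batch is a list of 11 tensors, each of shape [2]:
# first_batch[:10] are the ten context positions, first_batch[10] is the target position.
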
PREFIX_TRAIN = 'train'
PREFIX_VALID = 'dev-0'
BATCHES = []
# def read_train_file(folder_prefix, vocab):
#     dataset_x = []
#     dataset_y = []
#     counter_lines = 0
#     seq_len = 10
#     with lzma.open(f'{folder_prefix}/in.tsv.xz', 'r') as train, open(f'{folder_prefix}/expected.tsv', 'r') as expected:
#         for t_line, e_line in zip(train, expected):
#             t_line = t_line.decode("utf-8")
#             t_line = t_line.rstrip()
#             e_line = e_line.rstrip()
#             t_line = t_line.translate(str.maketrans('', '', string.punctuation))
#             t_line_splitted_by_tab = t_line.split('\t')
#             # t_line_cleared = t_line_splitted_by_tab[-2] + ' ' + e_line + ' ' + t_line_splitted_by_tab[-1]
#
#             whole_line = t_line_splitted_by_tab[-2] + ' ' + e_line + ' ' + t_line_splitted_by_tab[-1]
#
#             whole_line_splitted = list(scripts.get_words_from_line(whole_line))
#
#             whole_lines_splitted = divide_chunks(whole_line_splitted, 11)
#
#             for chunk_line in whole_line_splitted:
#
#                 left_context_splitted = chunk_line[0:10]
#
#                 seq_x = []
#                 for i in range(seq_len):
#                     index = -1 - i
#                     if len(left_context_splitted) < i + 1:
#                         seq_x.insert(0, '<empty>')
#                     else:
#                         seq_x.insert(0, left_context_splitted[-1 - i])
#
#                 left_vocabed = [vocab[t] for t in seq_x]
#
#                 dataset_x.append(left_vocabed)
#                 dataset_y.append([vocab[chunk_line[10]]])
#
#             counter_lines += 1
#             # if counter_lines > 20000:
#             #     break
#     return dataset_x, dataset_y

def read_dev_file(folder_prefix, vocab):
    dataset_x = []
    dataset_y = []
    counter_lines = 0
    seq_len = 10
    with lzma.open(f'{folder_prefix}/in.tsv.xz', 'r') as train, open(f'{folder_prefix}/expected.tsv', 'r') as expected:
        for t_line, e_line in zip(train, expected):
            t_line = t_line.decode("utf-8")
            t_line = t_line.rstrip()
            e_line = e_line.rstrip()
            t_line = t_line.translate(str.maketrans('', '', string.punctuation))
            t_line_splitted_by_tab = t_line.split('\t')
            # t_line_cleared = t_line_splitted_by_tab[-2] + ' ' + e_line + ' ' + t_line_splitted_by_tab[-1]

            left_context = t_line_splitted_by_tab[-2]
            left_context_splitted = list(scripts.get_words_from_line(left_context))

            # Take the last ten words of the left context, padding on the left with <empty>.
            seq_x = []
            for i in range(seq_len):
                index = -1 - i
                if len(left_context_splitted) < i + 1:
                    seq_x.insert(0, '<empty>')
                else:
                    seq_x.insert(0, left_context_splitted[-1 - i])

            left_vocabed = [vocab[t] for t in seq_x]

            dataset_x.append(left_vocabed)
            dataset_y.append([vocab[e_line]])

            counter_lines += 1
            # if counter_lines > 20000:
            #     break
    return dataset_x, dataset_y

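When the left context is shorter than ten words, the loop above pads it on the left with '<empty>' so that seq_x always has exactly ten entries. A small worked example with a made-up three-word context:

left_context_splitted = ['out', 'of', 'the']    # hypothetical short context
seq_x = []
for i in range(10):
    if len(left_context_splitted) < i + 1:
        seq_x.insert(0, '<empty>')
    else:
        seq_x.insert(0, left_context_splitted[-1 - i])
# seq_x == ['<empty>'] * 7 + ['out', 'of', 'the']
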
def read_test_file(folder_prefix, vocab):
    dataset_x = []
    dataset_y = []
    counter_lines = 0
    seq_len = 10
    with lzma.open(f'{folder_prefix}/in.tsv.xz', 'r') as train:
        for t_line in train:
            t_line = t_line.decode("utf-8")
            t_line = t_line.rstrip()
            t_line = t_line.translate(str.maketrans('', '', string.punctuation))
            t_line_splitted_by_tab = t_line.split('\t')
            # t_line_cleared = t_line_splitted_by_tab[-2] + ' ' + e_line + ' ' + t_line_splitted_by_tab[-1]

            left_context = t_line_splitted_by_tab[-2]
            left_context_splitted = list(scripts.get_words_from_line(left_context))

            seq_x = []
            for i in range(seq_len):
                index = -1 - i
                if len(left_context_splitted) < i + 1:
                    seq_x.insert(0, '<empty>')
                else:
                    seq_x.insert(0, left_context_splitted[-1 - i])

            left_vocabed = [vocab[t] for t in seq_x]

            dataset_x.append(left_vocabed)

            counter_lines += 1
            # if counter_lines > 20000:
            #     break
    return dataset_x

# train_set_x, train_set_y = read_file(PREFIX_TRAIN, vocab)
dev_set_x, dev_set_y = read_dev_file(PREFIX_VALID, vocab)

test_set_x = read_test_file('test-A', vocab)

# train_data_x = DataLoader(train_set_x, batch_size=4048)
# train_data_y = DataLoader(train_set_y, batch_size=4048)

# train_data_x = DataLoader(train_set_x, batch_size=4048)
# train_data_y = DataLoader(train_set_y, batch_size=4048)


dev_data_x = DataLoader(dev_set_x, batch_size=1)
dev_data_y = DataLoader(dev_set_y, batch_size=1)


test_set_x = DataLoader(test_set_x, batch_size=1)

# pdb.set_trace()
device = utils.device

model = utils.LanguageModel(scripts.vocab_size, utils.embed_size).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=utils.learning_rate)
criterion = torch.nn.NLLLoss()
model.train()

step = 0
last_best_acc = -1
epochs = 3

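utils.LanguageModel, utils.embed_size, utils.learning_rate and utils.device live in a helper module that is not part of this commit. Judging from how the model is called below (a list of ten index tensors goes in, softmax probabilities come out, and the loss is NLLLoss over torch.log of the output), a compatible stand-in might look roughly like the sketch that follows; this is an assumption for illustration, not the author's actual implementation.

class HypotheticalContextLM(nn.Module):
    # Stand-in with the interface the loop below expects (assumption, not utils.LanguageModel).
    def __init__(self, vocabulary_size, embedding_size):
        super().__init__()
        self.embedding = nn.Embedding(vocabulary_size, embedding_size)
        self.hidden_to_vocab = nn.Linear(10 * embedding_size, vocabulary_size)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, context_tokens):
        # context_tokens: list of 10 tensors, each of shape [batch_size]
        embedded = torch.cat([self.embedding(t) for t in context_tokens], dim=1)
        return self.softmax(self.hidden_to_vocab(embedded))   # [batch_size, vocabulary_size] probabilities
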
for epoch in range(epochs):
    model.train()
    for batch in train_data:
        # batch is a list of 11 tensors of shape [BATCH_SIZE]:
        # the first ten are the context, the last one is the target word.
        x = batch[:10]
        y = [batch[10]]

        x = [i.to(device) for i in x]
        y = y[0].to(device)
        optimizer.zero_grad()
        ypredicted = model(x)
        # pdb.set_trace()
        loss = criterion(torch.log(ypredicted), y)
        if step % 10000 == 0:
            print('Step: ', step, loss)
            # torch.save(model.state_dict(), f'model1_{step}.bin')
        step += 1
        loss.backward()
        optimizer.step()

    # evaluation
    model.eval()
    y_predicted = []
    top_50_true = 0
    for d_x, d_y in zip(dev_data_x, dev_data_y):
        # pdb.set_trace()
        d_x = [i.to(device) for i in d_x]
        # d_y = d_y.to(device)

        optimizer.zero_grad()
        ypredicted = model(d_x)

        # Count how often the gold word lands among the model's top 64 predictions.
        top = torch.topk(ypredicted[0], 64)
        top_indices = top.indices.tolist()
        if d_y[0] in top_indices:
            top_50_true += 1
    my_acc = top_50_true / len(dev_data_y)
    print('My_accuracy: ', my_acc, ", epoch: ", epoch)
    if my_acc > last_best_acc:
        print('NEW BEST -- My_accuracy: ', my_acc, ", epoch: ", epoch)
        last_best_acc = my_acc
        best_model = copy.deepcopy(model)
        torch.save(model.state_dict(), f'model_last_best_.bin')
    if epoch % 15 == 0:
        print('Epoch: ', epoch, step, loss)
        # torch.save(model.state_dict(), f'model_epoch_{epoch}_.bin')

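One note on the loss above: criterion(torch.log(ypredicted), y) assumes the model emits probabilities and takes their log afterwards, which can underflow for very small values. If the model were changed to return raw scores instead, an equivalent and more numerically stable formulation would be cross-entropy, which fuses log-softmax and NLL (a suggestion, not what this commit uses):

stable_criterion = torch.nn.CrossEntropyLoss()
# loss = stable_criterion(raw_scores, y)   # raw_scores: unnormalised logits, shape [batch_size, vocab_size]
# equivalent to: torch.nn.NLLLoss()(torch.log_softmax(raw_scores, dim=1), y)
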
# inference
print('inference')
inference_result = []
for d_x, d_y in zip(dev_data_x, dev_data_y):
    # pdb.set_trace()
    d_x = [i.to(device) for i in d_x]
    # d_y = d_y.to(device)

    optimizer.zero_grad()
    ypredicted = model(d_x)

    top = torch.topk(ypredicted[0], 10)
    top_indices = top.indices.tolist()
    top_probs = top.values.tolist()
    top_words = vocab.lookup_tokens(top_indices)

    string_to_print = ''

    # Emit "word:probability" pairs, skipping <unk>, and give the remaining
    # probability mass to the bare catch-all entry at the end of the line.
    sum_probs = 0
    for w, p in zip(top_words, top_probs):
        # print(top_words)
        if '<unk>' in w:
            continue
        string_to_print += f"{w}:{p} "
        sum_probs += p

    if string_to_print == '':
        inference_result.append("the:0.2 a:0.3 :0.5")
        continue
    unknown_prob = 1 - sum_probs
    string_to_print += f":{unknown_prob}"

    inference_result.append(string_to_print)

with open('dev-0/out.tsv', 'w') as f:
    for line in inference_result:
        f.write(line + '\n')

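Each line written to dev-0/out.tsv is a space-separated list of word:probability pairs, closed by a bare :probability entry that absorbs the remaining mass. A small sanity check on one such line (the numbers are made up):

line = "the:0.31 a:0.12 of:0.08 :0.49"            # illustrative values only
pairs = [field.rsplit(':', 1) for field in line.split()]
assert abs(sum(float(p) for _, p in pairs) - 1.0) < 1e-6
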
print('inference test')
inference_result = []
for d_x in test_set_x:
    # pdb.set_trace()
    d_x = [i.to(device) for i in d_x]
    # d_y = d_y.to(device)

    optimizer.zero_grad()
    ypredicted = model(d_x)

    top = torch.topk(ypredicted[0], 64)
    top_indices = top.indices.tolist()
    top_probs = top.values.tolist()
    top_words = vocab.lookup_tokens(top_indices)

    string_to_print = ''

    sum_probs = 0
    for w, p in zip(top_words, top_probs):
        # print(top_words)
        if '<unk>' in w:
            continue
        string_to_print += f"{w}:{p} "
        sum_probs += p

    if string_to_print == '':
        inference_result.append("the:0.2 a:0.3 :0.5")
        continue
    unknown_prob = 1 - sum_probs
    string_to_print += f":{unknown_prob}"

    inference_result.append(string_to_print)

with open('test-A/out.tsv', 'w') as f:
    for line in inference_result:
        f.write(line + '\n')

print('All done')