# forked from kubapok/lalka-lm
import numpy as np
import torch
from sklearn.model_selection import train_test_split
import nltk
from nltk.tokenize import word_tokenize

nltk.download('punkt')

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Corpus paths: the full training corpus and the train/validation split derived from it.
train_a = "train/train.tsv"
lalka_path_train = 'train/train_train.tsv'
lalka_path_valid = 'train/train_test.tsv'


def open_files(path_a, path_b, path_c):
    # Split the source corpus into train (80%) and validation (20%) files.
    with open(path_a, "r") as path:
        lines = path.readlines()
        train, test = train_test_split(lines, test_size=0.2)
        with open(path_b, "w") as out_train_file:
            for i in train:
                out_train_file.write(i)
        with open(path_c, "w") as out_test_file:
            for i in test:
                out_test_file.write(i)


# Create the split before reading it back in.
open_files(train_a, lalka_path_train, lalka_path_valid)

corpora_train = open(lalka_path_train).read()
corpora_train_tokenized = list(word_tokenize(corpora_train))
corpora_train_tokenized = [token.lower() for token in corpora_train_tokenized]

# Vocabulary: the first 15001 token types (in sorted order) plus four special symbols.
vocab_itos = sorted(set(corpora_train_tokenized))
vocab_itos = vocab_itos[:15005]
vocab_itos[15001] = "<UNK>"
vocab_itos[15002] = "<BOS>"
vocab_itos[15003] = "<EOS>"
vocab_itos[15004] = "<PAD>"
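# Sanity check for the hard-coded special-token indices above: the corpus must
# contain at least 15005 distinct tokens for positions 15001-15004 to exist.
assert len(vocab_itos) == 15005
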
BATCH_SIZE = 128
EPOCHS = 15

# Per-epoch perplexity history.
history_ppl_train = []
history_ppl_valid = []

# Reverse mapping: token -> id.
vocab_stoi = dict()
for i, token in enumerate(vocab_itos):
    vocab_stoi[token] = i

# Each training sample is an n-gram: NGRAMS - 1 context tokens plus the token to predict.
NGRAMS = 5


def set_ppl(dataset_id_list):
    # Perplexity of the model on a tensor of n-gram samples: exp of the mean batch loss.
    lm.eval()

    batches = 0
    loss_sum = 0

    with torch.no_grad():
        for i in range(0, len(dataset_id_list) - BATCH_SIZE + 1, BATCH_SIZE):
            X = dataset_id_list[i:i + BATCH_SIZE, :NGRAMS - 1]
            Y = dataset_id_list[i:i + BATCH_SIZE, NGRAMS - 1]
            predictions = lm(X)
            loss = criterion(predictions, Y)

            loss_sum += loss.item()
            batches += 1

    return np.exp(loss_sum / batches)


def get_samples(dataset):
    # Slide a window of length NGRAMS over the id list to build the samples.
    samples = []
    for i in range(len(dataset) - NGRAMS):
        samples.append(dataset[i:i + NGRAMS])
    return samples


def get_token_id(dataset):
    # Map tokens to ids, left-padding the first context and adding <BOS>/<EOS>;
    # out-of-vocabulary tokens fall back to <UNK>.
    token_id_list = [vocab_stoi['<PAD>']] * (NGRAMS - 1) + [vocab_stoi['<BOS>']]
    for token in dataset:
        try:
            token_id_list.append(vocab_stoi[token])
        except KeyError:
            token_id_list.append(vocab_stoi['<UNK>'])
    token_id_list.append(vocab_stoi['<EOS>'])
    return token_id_list


# Encode the training and validation corpora as NGRAMS-long windows of token ids.
train_id_list = get_token_id(corpora_train_tokenized)
train_id_list = get_samples(train_id_list)
train_id_list = torch.tensor(train_id_list, dtype=torch.long, device=device)

corpora_valid = open(lalka_path_valid).read()
corpora_valid_tokenized = list(word_tokenize(corpora_valid))
corpora_valid_tokenized = [token.lower() for token in corpora_valid_tokenized]

valid_id_list = get_token_id(corpora_valid_tokenized)
valid_id_list = torch.tensor(get_samples(valid_id_list), dtype=torch.long, device=device)
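# Quick sanity check: each row is one NGRAMS-long window of token ids.
print(train_id_list.shape, valid_id_list.shape)
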
class GRU(torch.nn.Module):

    def __init__(self):
        super(GRU, self).__init__()
        self.emb = torch.nn.Embedding(len(vocab_itos), 100)
        self.rec = torch.nn.GRU(100, 256, 1, batch_first=True)
        self.fc1 = torch.nn.Linear(256, len(vocab_itos))

    def forward(self, x):
        emb = self.emb(x)
        output, h_n = self.rec(emb)
        # Score the next token from the final hidden state.
        hidden = h_n.squeeze(0)
        out = self.fc1(hidden)
        return out


lm = GRU().to(device)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(lm.parameters(), lr=0.0001)

for epoch in range(EPOCHS):

    batches = 0
    loss_sum = 0
    lm.train()
    for i in range(0, len(train_id_list) - BATCH_SIZE + 1, BATCH_SIZE):
        # X holds the NGRAMS - 1 context tokens, Y the token to predict.
        X = train_id_list[i:i + BATCH_SIZE, :NGRAMS - 1]
        Y = train_id_list[i:i + BATCH_SIZE, NGRAMS - 1]
        predictions = lm(X)
        loss = criterion(predictions, Y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        batches += 1

    ppl_train = set_ppl(train_id_list)
    ppl_valid = set_ppl(valid_id_list)

    history_ppl_train.append(ppl_train)
    history_ppl_valid.append(ppl_valid)
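    # Optional: report perplexity after each epoch.
    print(f"epoch {epoch}: train ppl {ppl_train:.2f}, valid ppl {ppl_valid:.2f}")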

# Quick check: predict the token that follows a sample sentence from the novel.
tokenized = list(word_tokenize('Gości innych nie widział oprócz spółleśników'))
tokenized = [token.lower() for token in tokenized]

id_list = []
for word in tokenized:
    if word in vocab_stoi:
        id_list.append(vocab_stoi[word])
    else:
        id_list.append(vocab_stoi['<UNK>'])

lm.eval()

id_list = torch.tensor(id_list, dtype=torch.long, device=device)

preds = lm(id_list.unsqueeze(0))

print(vocab_itos[torch.argmax(torch.softmax(preds, 1), 1).item()])

# Generate a 30-token continuation of the prompt 'Lalka' by sampling from the top candidates.
tokenized = list(word_tokenize('Lalka'))
tokenized = [token.lower() for token in tokenized]

id_list = []
for word in tokenized:
    if word in vocab_stoi:
        id_list.append(vocab_stoi[word])
    else:
        id_list.append(vocab_stoi['<UNK>'])

id_list = torch.tensor([id_list], dtype=torch.long, device=device)

candidates_number = 10
for i in range(30):
    preds = lm(id_list)
    candidates = torch.topk(torch.softmax(preds, 1), candidates_number)[1][0].cpu().numpy()
    # Resample until the candidate is a regular token (ids above 15000 are special symbols).
    candidate = 15001
    while candidate > 15000:
        candidate = candidates[np.random.randint(candidates_number)]
    id_list = torch.cat((id_list, torch.tensor([[candidate]], device=device)), 1)
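
# Optional: decode and print the sampled continuation.
print(' '.join(vocab_itos[token_id] for token_id in id_list[0].tolist()))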
with open("dev-0/in.tsv", "r") as dev_path:
|
|
nr_of_dev_lines = len(dev_path.readlines())
|
|
with open("test-A/in.tsv", "r") as test_a_path:
|
|
nr_of_test_a_lines = len(test_a_path.readlines())
|
|
with open("dev-0/out.tsv", "w") as out_dev_file:
|
|
for i in range(nr_of_dev_lines):
|
|
preds = lm(id_list)
|
|
candidates = torch.topk(torch.softmax(preds, 1), candidates_number)[1][0].cpu().numpy()
|
|
candidate = 15001
|
|
while candidate > 15000:
|
|
candidate = candidates[np.random.randint(candidates_number)]
|
|
id_list = torch.cat((id_list, torch.tensor([[candidate]], device=device)), 1)
|
|
out_dev_file.write(vocab_itos[candidate] + '\n')
|
|
with open("test-A/out.tsv", "w") as out_test_file:
|
|
for i in range(nr_of_dev_lines):
|
|
preds = lm(id_list)
|
|
candidates = torch.topk(torch.softmax(preds, 1), candidates_number)[1][0].cpu().numpy()
|
|
candidate = 15001
|
|
while candidate > 15000:
|
|
candidate = candidates[np.random.randint(candidates_number)]
|
|
id_list = torch.cat((id_list, torch.tensor([[candidate]], device=device)), 1)
|
|
out_test_file.write(vocab_itos[candidate] + '\n') |