forked from kubapok/lalka-lm
34 KiB
34 KiB
import nltk
import torch
import pandas as pd
import csv
from sklearn.model_selection import train_test_split
from nltk.tokenize import word_tokenize as tokenize
from tqdm.notebook import tqdm
import numpy as np
#downloads
nltk.download('punkt')
[nltk_data] Downloading package punkt to /home/ubuntu/nltk_data... [nltk_data] Package punkt is already up-to-date!
True
#settings
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('Using {} device'.format(device))
criterion = torch.nn.CrossEntropyLoss()
BATCH_SIZE = 128
EPOCHS = 15
NGRAMS = 5
Using cpu device
#training data prepare
train_data = pd.read_csv('train/train.tsv', header=None, error_bad_lines=False, quoting=csv.QUOTE_NONE, sep='\t')
train_data = train_data[0]
train_set, train_test_set = train_test_split(train_data, test_size = 0.2)
with open("train/train_set.tsv", "w", encoding='utf-8') as out_train_set:
for i in train_set:
out_train_set.write(i)
with open("train/train_test_set.tsv", "w", encoding='utf-8') as out_train_test_set:
for i in train_test_set:
out_train_test_set.write(i)
train_set_tok = list(tokenize(open('train/train_set.tsv').read()))
train_set_tok = [line.lower() for line in train_set_tok]
vocab_itos = sorted(set(train_set_tok))
print(len(vocab_itos))
vocab_itos = vocab_itos[:15005]
vocab_itos[15001] = "<UNK>"
vocab_itos[15002] = "<BOS>"
vocab_itos[15003] = "<EOS>"
vocab_itos[15004] = "<PAD>"
print(len(vocab_itos))
vocab_stoi = dict()
for i, token in enumerate(vocab_itos):
vocab_stoi[token] = i
train_ids = [vocab_stoi['<PAD>']] * (NGRAMS-1) + [vocab_stoi['<BOS>']]
for token in train_set_tok:
try:
train_ids.append(vocab_stoi[token])
except KeyError:
train_ids.append(vocab_stoi['<UNK>'])
train_ids.append(vocab_stoi['<EOS>'])
samples = []
for i in range(len(train_ids)-NGRAMS):
samples.append(train_ids[i:i+NGRAMS])
train_ids = torch.tensor(samples,device=device)
train_test_set_tok = list(tokenize(open('train/train_test_set.tsv').read()))
train_test_set_tok = [line.lower() for line in train_test_set_tok]
train_test_ids = [vocab_stoi['<PAD>']] * (NGRAMS-1) + [vocab_stoi['<BOS>']]
for token in train_test_set_tok:
try:
train_test_ids.append(vocab_stoi[token])
except KeyError:
train_test_ids.append(vocab_stoi['<UNK>'])
train_test_ids.append(vocab_stoi['<EOS>'])
samples = []
for i in range(len(train_test_ids)-NGRAMS):
samples.append(train_test_ids[i:i+NGRAMS])
train_test_ids = torch.tensor(samples, dtype=torch.long, device=device)
28558 15005
#GRU
class GRU(torch.nn.Module):
def __init__(self):
super(GRU, self).__init__()
self.emb = torch.nn.Embedding(len(vocab_itos),100)
self.rec = torch.nn.GRU(100, 256, 1, batch_first = True)
self.fc1 = torch.nn.Linear( 256 ,len(vocab_itos))
self.dropout = torch.nn.Dropout(0.5)
def forward(self, x):
emb = self.emb(x)
#emb = self.dropout(emb)
output, h_n = self.rec(emb)
hidden = h_n.squeeze(0)
out = self.fc1(hidden)
out = self.dropout(out)
return out
lm = GRU().to(device)
optimizer = torch.optim.Adam(lm.parameters(),lr=0.0001)
hppl_train = []
hppl_train_test = []
for epoch in range(EPOCHS):
batches = 0
loss_sum =0
acc_score = 0
lm.train()
for i in tqdm(range(0, len(train_ids)-BATCH_SIZE+1, BATCH_SIZE)):
X = train_ids[i:i+BATCH_SIZE,:NGRAMS-1]
Y = train_ids[i:i+BATCH_SIZE,NGRAMS-1]
predictions = lm(X)
loss = criterion(predictions,Y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_sum += loss.item()
batches += 1
#ppl train
lm.eval()
batches = 0
loss_sum =0
acc_score = 0
for i in range(0, len(train_ids)-BATCH_SIZE+1, BATCH_SIZE):
X = train_ids[i:i+BATCH_SIZE,:NGRAMS-1]
Y = train_ids[i:i+BATCH_SIZE,NGRAMS-1]
predictions = lm(X)
loss = criterion(predictions,Y)
loss_sum += loss.item()
batches += 1
ppl_train = np.exp(loss_sum / batches)
#ppl train test
lm.eval()
batches = 0
loss_sum =0
acc_score = 0
for i in range(0, len(train_test_ids)-BATCH_SIZE+1, BATCH_SIZE):
X = train_test_ids[i:i+BATCH_SIZE,:NGRAMS-1]
Y = train_test_ids[i:i+BATCH_SIZE,NGRAMS-1]
predictions = lm(X)
loss = criterion(predictions,Y)
loss_sum += loss.item()
batches += 1
ppl_train_test = np.exp(loss_sum / batches)
hppl_train.append(ppl_train)
hppl_train_test.append(ppl_train_test)
print('epoch: ', epoch)
print('train ppl: ', ppl_train)
print('train_test ppl: ', ppl_train_test)
print()
HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))
epoch: 0 train ppl: 429.60890594777385 train_test ppl: 354.7605940026038
HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))
epoch: 1 train ppl: 385.04263303807164 train_test ppl: 320.5323274780826
HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))
epoch: 2 train ppl: 388.15715746591627 train_test ppl: 331.5143312260392
HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))
epoch: 3 train ppl: 364.4566197255965 train_test ppl: 316.9918140368464
HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))
epoch: 4 train ppl: 344.1713452631125 train_test ppl: 306.67499426384535
HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))
epoch: 5 train ppl: 325.7237671473614 train_test ppl: 295.83423173746667
HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))
epoch: 6 train ppl: 323.8838574773216 train_test ppl: 302.95495879615413
HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))
epoch: 7 train ppl: 313.13238735049896 train_test ppl: 300.0722307805052
HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))
epoch: 8 train ppl: 308.2248282795148 train_test ppl: 303.25779664571974
HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))
epoch: 9 train ppl: 293.68307666273853 train_test ppl: 295.00145166486533
HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))
epoch: 10 train ppl: 279.2453691179102 train_test ppl: 287.8307587065576
HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))
epoch: 11 train ppl: 267.2034758169644 train_test ppl: 282.18074183208086
HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))
epoch: 12 train ppl: 260.65159391269935 train_test ppl: 281.92398288442536
HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))
epoch: 13 train ppl: 246.21807765812747 train_test ppl: 271.8481103799856
HBox(children=(FloatProgress(value=0.0, max=1629.0), HTML(value='')))
epoch: 14 train ppl: 234.50125342517168 train_test ppl: 265.61149027211843
#'Gości' i 'Lalka'
tokenized = list(tokenize('Gości innych nie widział oprócz spółleśników'))
tokenized = [token.lower() for token in tokenized]
ids = []
for word in tokenized:
if word in vocab_stoi:
ids.append(vocab_stoi[word])
else:
ids.append(vocab_stoi['<UNK>'])
lm.eval()
ids = torch.tensor(ids, dtype = torch.long, device = device)
preds= lm(ids.unsqueeze(0))
vocab_itos[torch.argmax(torch.softmax(preds,1),1).item()]
tokenized = list(tokenize('Lalka'))
tokenized = [token.lower() for token in tokenized]
ids = []
for word in tokenized:
if word in vocab_stoi:
ids.append(vocab_stoi[word])
else:
ids.append(vocab_stoi['<UNK>'])
ids = torch.tensor([ids], dtype = torch.long, device = device)
candidates_number = 10
for i in range(30):
preds= lm(ids)
candidates = torch.topk(torch.softmax(preds,1),candidates_number)[1][0].cpu().numpy()
candidate = 15001
while candidate > 15000:
candidate = candidates[np.random.randint(candidates_number)]
print(vocab_itos[candidate])
ids = torch.cat((ids, torch.tensor([[candidate]], device = device)), 1)
człowiek i nagle .— nie będzie , nie jestem pewna do niego i nie , jak pan ; jest . a jeżeli , nawet po . na lewo po kilka
#dev0 pred
with open("dev-0/in.tsv", "r", encoding='utf-8') as dev_path:
nr_of_dev_lines = len(dev_path.readlines())
with open("dev-0/out.tsv", "w", encoding='utf-8') as out_dev_file:
for i in range(nr_of_dev_lines):
preds= lm(ids)
candidates = torch.topk(torch.softmax(preds,1),candidates_number)[1][0].cpu().numpy()
candidate = 15001
while candidate > 15000:
candidate = candidates[np.random.randint(candidates_number)]
print(vocab_itos[candidate])
ids = torch.cat((ids, torch.tensor([[candidate]], device = device)), 1)
out_dev_file.write(vocab_itos[candidate] + '\n')
. o mnie nie było .— ani . jest jak . ale co pan nie obchodzi ! nawet nie jest ! ? . i jeszcze do . po co do pani , który , nawet , jak ona do panny ; i nawet : o co na myśl ! . po , jak i ja ? . a jeżeli nie o o ? po nie był pani .— . pan mnie nie , nawet mnie o .— . nie jestem , jak on , jak nie , nawet i nie . a jeżeli co ? i kto ? ! na jego ostrzyżonej ) ? do mnie i do na mnie i po co i jeszcze : czy nie , pani dobrodziejko ! na nie i po jego na lewo , ale , który na niego nie było ; nie i nie ma na , a pani nie mam ? . nie może na mnie i jeszcze nie mam ? ale , i już , nie mam . i cóż ! ) . nie jestem o mnie nie i nic ? i ja .— nie chcę , na lewo nie było na jej , nie na jej nie , ażeby jak . ale nie było o nią i , a nawet nie jest . nie chcę . a co pan do niej , który , na jego . była już : , i nawet go o nim ; o jej nie było na niego albo i . gdy go .— co mi do domu ? albo i , a pan , panie nie ! ! ! ja i na jej ochronę do , co mnie nie mam .— może , a nie ma na mnie nie , ani i nawet nie na nic ! . po chwili .— nie ma pan ignacy .— może mnie nie ? nawet ? po chwili . nie był ; na myśl , a nawet mnie ? do na nią ; i jeszcze jak on . i nawet do końca na jego nie i nawet do domu ? i o co dzień do pani ? a , czy nie chcę .— ja ? i o . ja , bo nie ma być ? , nie mam na co .— , ja ? , co ? ) do pana . na lewo . nie na nic . ale nie , a ja ? , a co do pani . była do pani meliton : albo o , ażeby , ale co , jak ona na niego ; . ale jeszcze na , na jego miejscu i była .— i ja .— na nią nie było .— co do mnie , ale nawet , do licha na myśl i do .— o mnie pan na co dzień na głowie .— co . nie jest ci .— pan . nie
#testA pred
with open("test-A/in.tsv", "r", encoding='utf-8') as test_a_path:
nr_of_test_a_lines = len(test_a_path.readlines())
with open("test-A/out.tsv", "w", encoding='utf-8') as out_test_file:
for i in range(nr_of_dev_lines):
preds= lm(ids)
candidates = torch.topk(torch.softmax(preds,1),candidates_number)[1][0].cpu().numpy()
candidate = 15001
while candidate > 15000:
candidate = candidates[np.random.randint(candidates_number)]
print(vocab_itos[candidate])
ids = torch.cat((ids, torch.tensor([[candidate]], device = device)), 1)
out_test_file.write(vocab_itos[candidate] + '\n')
, a ja . na co na mnie i kto , ale nawet na mnie ! co ja .— już ? ! ) i pan na myśl ; , a nawet nie , jak pan na mnie na , i o co ja nie chcę .— , nie mam ? ? nie . pani nie jest na co nie może i cóż nie . a jeżeli jak ona ! na dole . nie był pan . nie jest jeszcze jak pani ? i o ? po ? po co dzień ? na , co pan do niego na głowie .— . nie był . na myśl ; i ja . na lewo ; była go , na jej .— o ! ? na co ! ) do głowy . i nawet do niej . nie był ; o ile o jego o ; ale pan ignacy .— nie ma pan do . ja do mego i nie będzie o mnie i już . o co pan ignacy ? na którym , kiedy go na jej ; ale co , a co pan ? i kto mu pan , co ? o , i kto by mnie do głowy .— a ! nawet o niej na myśl ? i już do . nie na mnie nie mam . była już . ( , nie ! , jak on mnie .— pan . ( może na nie było i , który by mu nie . i dopiero . a , jak ja , na którym ? a jeżeli jest bardzo ? ! , bo już . nie chcę go do paryża .— co dzień pan nie . ? co na myśl ! , a może jeszcze na niego , nie ma , a pan nie będzie .— nic mnie pan . * . ja nie , pani dobrodziejko .— i cóż . pan nie jadł na nich ! ; na lewo na mnie i na nogi ? .— nie chcę ? , co by ? ! o ? po i nawet , jak ja . ale o jej ! , jak ja już nic ! ) ! cha , ale nawet do głowy na , nie mógł nawet nie mógł do niego nie na mnie ? ) , ale jeszcze . po . o mnie na jego na myśl i nawet na lewo na głowie na górę i po otworzeniu ; ale co do na jego .— a pan i co . jest pan ignacy do paryża nie mam . a jeżeli na jej ? . o nie i nie . o jego po pokoju , jak ja już : od na do ; ale nawet o niej nie jest , ale , jak , na jej . nie był ani ani do , a na nią : nawet co nie . na