2022-03-31 14:53:43 +02:00
|
|
|
#!/usr/bin/env python
|
|
|
|
# coding: utf-8
|
|
|
|
|
2022-05-29 23:56:57 +02:00
|
|
|
# In[2]:
|
2022-03-31 14:53:43 +02:00
|
|
|
|
|
|
|
|
2022-05-29 23:56:57 +02:00
|
|
|
import torch
|
|
|
|
from torch import nn, optim
|
|
|
|
from torch.utils.data import DataLoader
|
|
|
|
import numpy as np
|
|
|
|
from collections import Counter
|
|
|
|
import re
|
2022-03-31 14:53:43 +02:00
|
|
|
import lzma
|
|
|
|
import csv
|
|
|
|
|
|
|
|
|
2022-05-29 23:56:57 +02:00
|
|
|
# In[3]:
|
|
|
|
|
|
|
|
|
|
|
|
# Device on which models and tensors are placed. Falls back to the CPU so the
# script still runs on machines without a CUDA-capable GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
|
|
|
|
|
|
|
|
# In[4]:
|
|
|
|
|
|
|
|
|
|
|
|
class Dataset(torch.utils.data.Dataset):
    """Sliding-window next-word dataset built from a gap-filling corpus.

    The corpus is read from ``<data_dir>/in.tsv.xz`` (tab-separated rows;
    columns 6 and 7 hold the text before and after the gap) and
    ``<data_dir>/expected.tsv`` (first column holds the word that fills the
    gap). Each expected word is spliced between its two contexts, the result
    is lower-cased, every character other than ``a-z`` and space is removed,
    and the text is split on single spaces.

    Because the split is on a single space, runs of spaces produce
    empty-string tokens; these are kept and are part of the vocabulary.

    Args:
        sequence_length: Number of tokens in each input/target window.
        max_rows: Maximum number of corpus rows to use (``None`` for all).
            Fewer rows are used when the corpus is shorter.
        data_dir: Directory containing ``in.tsv.xz`` and ``expected.tsv``.
    """

    def __init__(
        self,
        sequence_length,
        max_rows=5000,
        data_dir='train',
    ):
        self.sequence_length = sequence_length
        self.max_rows = max_rows
        self.data_dir = data_dir
        self.words = self.load()
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load(self):
        """Read the corpus and return it as a flat list of tokens.

        Raises:
            IndexError: If a used input row has fewer than eight columns.
        """
        # Read through a context manager so the archive handle is closed.
        with lzma.open(f'{self.data_dir}/in.tsv.xz') as archive:
            raw = archive.read().decode('UTF-8')

        rows = raw.split('\n')
        # Only discard the final element when it is the empty string left by
        # a trailing newline; a last row without a newline is kept.
        if rows[-1] == '':
            rows.pop()
        rows = rows[:self.max_rows]

        # Doubly-escaped markers (two backslashes + "n") are blanked here;
        # the singly-escaped form is handled after the gap word is spliced in.
        contexts = [
            (columns[6].replace('\\\\n', ' '), columns[7].replace('\\\\n', ' '))
            for columns in (row.split('\t') for row in rows)
        ]

        with open(f'{self.data_dir}/expected.tsv', encoding='utf-8', newline='') as file:
            # A blank line becomes an empty word so rows stay aligned.
            words = [line[0] if line else '' for line in csv.reader(file, delimiter="\t")]
        words = words[:self.max_rows]

        # zip() stops at the shorter of the two inputs, so a corpus with
        # fewer than max_rows rows no longer raises IndexError.
        text = []
        for (left, right), word in zip(contexts, words):
            t = left + ' ' + word + ' ' + right + ' '
            text.append(t.replace('\\n', ' '))

        text = ' '.join(text).lower()
        text = re.sub('[^a-z ]', '', text)
        return text.split(' ')

    def get_uniq_words(self):
        """Return the distinct tokens, most frequent first.

        Tokens with equal counts keep their first-occurrence order, because
        ``sorted`` is stable.
        """
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        """Return the number of windows; never negative."""
        return max(0, len(self.words_indexes) - self.sequence_length)

    def __getitem__(self, index):
        """Return ``(inputs, targets)``; targets are the inputs shifted by one token."""
        return (
            torch.tensor(self.words_indexes[index:index+self.sequence_length]),
            torch.tensor(self.words_indexes[index+1:index+self.sequence_length+1]),
        )
|
2022-03-31 14:53:43 +02:00
|
|
|
|
|
|
|
|
2022-05-29 23:56:57 +02:00
|
|
|
# In[5]:
|
|
|
|
|
|
|
|
|
|
|
|
# Forward-direction corpus with windows of five tokens.
dataset = Dataset(sequence_length=5)

# In[6]:

# Look at one (inputs, targets) pair. As a bare expression the value is
# discarded when this file is run as a plain script.
dataset[200]

# In[7]:

# Map a window of token ids back to words.
[dataset.index_to_word[idx] for idx in (0, 231, 19, 98, 189)]

# In[8]:

# The same window shifted by one position.
[dataset.index_to_word[idx] for idx in (231, 19, 98, 189, 5)]

# In[9]:

# A single hand-written window used to smoke-test the model further down.
sample_window = [[0, 231, 19, 98, 189]]
input_tensor = torch.tensor(sample_window, dtype=torch.int32).to(device)
del sample_window
|
|
|
|
|
|
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
|
|
|
|
|
|
class Model(nn.Module):
    """Embedding -> multi-layer LSTM -> linear next-word language model.

    The LSTM is created without ``batch_first``, so ``forward`` treats dim 0
    of its input as time steps and dim 1 as the batch (PyTorch's default
    layout).

    Args:
        vocab_size: Number of distinct tokens; sizes both the embedding table
            and the output layer.
        embedding_dim: Width of the token embeddings.
        lstm_size: Hidden size of every LSTM layer.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout applied between LSTM layers.
    """

    def __init__(self, vocab_size, embedding_dim=128, lstm_size=128, num_layers=3, dropout=0.2):
        super(Model, self).__init__()
        self.lstm_size = lstm_size
        self.embedding_dim = embedding_dim
        self.num_layers = num_layers

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=self.embedding_dim,
        )
        self.lstm = nn.LSTM(
            # The LSTM consumes embeddings, so its input width is the
            # embedding width (previously wired to lstm_size, which only
            # worked because both were 128).
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=dropout,
        )
        self.fc = nn.Linear(self.lstm_size, vocab_size)

    def forward(self, x, prev_state = None):
        """Return ``(logits, state)`` for a tensor of token ids.

        Args:
            x: Integer tensor of token ids.
            prev_state: Optional ``(h, c)`` tuple as returned by
                ``init_state`` or a previous call; ``None`` lets the LSTM
                start from zeros.

        Returns:
            ``logits`` with one score per vocabulary entry on the last axis,
            and the LSTM state after the final step.
        """
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        """Return a zero ``(h, c)`` state on the same device as the model.

        Args:
            sequence_length: Size of dim 1 of the state tensors. With the
                LSTM layout used here that is the batch dimension; the
                parameter name is kept for compatibility with callers.
        """
        # Follow the model's own parameters instead of a module-level global
        # so the state always lives where the weights live.
        target = next(self.parameters()).device
        shape = (self.num_layers, sequence_length, self.lstm_size)
        return (torch.zeros(shape, device=target),
                torch.zeros(shape, device=target))
|
|
|
|
|
|
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
|
|
|
|
|
|
# Smoke-test an untrained model on the hand-written window. The vocabulary
# size is the number of distinct words; len(dataset) is the number of
# training windows and would only oversize the embedding and output layers.
model = Model(len(dataset.uniq_words)).to(device)

# In[ ]:

y_pred, state_h = model(input_tensor)

# In[ ]:

# Bare expressions: their values are discarded when run as a plain script.
y_pred

# In[ ]:

y_pred.shape
|
|
|
|
|
|
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
|
|
|
|
|
|
def train(dataset, model, max_epochs, batch_size):
    """Fit ``model`` on ``dataset`` with Adam and cross-entropy loss.

    Batches are drawn in dataset order (the loader is created without
    shuffling) and a progress dict is printed after every optimizer step.

    NOTE(review): the loader yields batches with the sample index on dim 0;
    if ``model`` uses an LSTM without ``batch_first``, that axis is
    consumed as time steps. Confirm this is intended.

    Args:
        dataset: Dataset yielding ``(inputs, targets)`` pairs of token ids.
        model: Module whose call returns ``(logits, state)``.
        max_epochs: Number of passes over the dataset.
        batch_size: Number of samples per batch.
    """
    model.train()

    loader = DataLoader(dataset, batch_size=batch_size)
    loss_fn = nn.CrossEntropyLoss()
    adam = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(max_epochs):
        for batch, (inputs, targets) in enumerate(loader):
            adam.zero_grad()
            inputs, targets = inputs.to(device), targets.to(device)

            logits, _ = model(inputs)
            # CrossEntropyLoss wants the class (vocabulary) axis at dim 1.
            loss = loss_fn(logits.transpose(1, 2), targets)

            loss.backward()
            adam.step()

            progress = {
                'epoch': epoch,
                'update in batch': batch,
                '/': len(loader),
                'loss': loss.item(),
            }
            print(progress)
|
2022-03-31 14:53:43 +02:00
|
|
|
|
|
|
|
|
2022-05-29 23:56:57 +02:00
|
|
|
# In[ ]:
|
2022-03-31 14:53:43 +02:00
|
|
|
|
|
|
|
|
2022-05-29 23:56:57 +02:00
|
|
|
# Build the forward model sized to the vocabulary and train it for a single
# epoch with batches of 64 windows.
model = Model(vocab_size=len(dataset.uniq_words))
model = model.to(device)
train(dataset, model, max_epochs=1, batch_size=64)
|
|
|
|
|
|
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
|
|
|
|
|
|
def predict(dataset, model, text, next_words=5):
    """Return the most probable next words after ``text``.

    The context is fed to the model as one sequence of shape
    ``(len(words), 1)``: time steps on dim 0 and a batch of one on dim 1,
    which is the layout a sequence-first LSTM expects. Feeding it as
    ``(1, len(words))`` would be read as ``len(words)`` independent one-step
    sequences, so only the last word would influence the result.

    Args:
        dataset: Object exposing ``word_to_index`` and ``index_to_word``.
        model: Trained model exposing ``init_state`` and returning
            ``(logits, state)`` when called.
        text: Context words separated by single spaces.
        next_words: How many candidates to return.

    Returns:
        List of ``(word, probability)`` tuples, most probable first.

    Raises:
        KeyError: If a context word is missing from ``dataset.word_to_index``.
    """
    model.eval()
    words = text.split(' ')

    # Place inputs wherever the model's weights live.
    target = next(model.parameters()).device
    x = torch.tensor([[dataset.word_to_index[w]] for w in words]).to(target)

    with torch.no_grad():
        state_h = model.init_state(1)  # batch of one sequence
        y_pred, state_h = model(x, state_h)

    # Logits after the final time step of the single sequence.
    last_word_logits = y_pred[-1][0]
    p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()

    ranked = sorted(zip(p, range(len(p))), reverse=True)[:next_words]
    return [(dataset.index_to_word[index], prob) for prob, index in ranked]
|
|
|
|
|
|
|
|
def _next_word_probs(dataset, net, tokens):
    """Run ``tokens`` through ``net`` as one sequence; return next-word probabilities."""
    target = next(net.parameters()).device
    # Shape (len(tokens), 1): time steps on dim 0, a batch of one on dim 1.
    x = torch.tensor([[dataset.word_to_index[w]] for w in tokens]).to(target)
    with torch.no_grad():
        logits, _ = net(x, net.init_state(1))
    return torch.nn.functional.softmax(logits[-1][0], dim=0).detach().cpu().numpy()


def predict2(dataset, model, model2, text, text2, next_words=5):
    """Predict a gap word from both its left and its right context.

    ``model`` reads ``text`` (the words before the gap) left to right;
    ``model2`` reads the words of ``text2`` (after the gap) in reverse order.
    The two probability vectors are averaged element-wise.

    Each context is fed as its own sequence with a batch-of-one state, so the
    two contexts may have different lengths. Previously ``model2``'s state
    was sized from ``text``, which failed when the lengths differed.

    NOTE(review): both models are indexed through ``dataset``, so ``model2``
    must have been trained with the same word-to-index mapping. Confirm
    against the dataset used to train it.

    Args:
        dataset: Object exposing ``word_to_index`` and ``index_to_word``.
        model: Model fed with ``text``.
        model2: Model fed with the reversed words of ``text2``.
        text: Words before the gap, separated by single spaces.
        text2: Words after the gap, separated by single spaces.
        next_words: How many candidates to return.

    Returns:
        List of ``(word, probability)`` tuples, most probable first.

    Raises:
        KeyError: If a context word is missing from ``dataset.word_to_index``.
    """
    model.eval()
    model2.eval()

    words = text.split(' ')
    words2 = text2.split(' ')
    words2.reverse()

    p = _next_word_probs(dataset, model, words)
    p2 = _next_word_probs(dataset, model2, words2)

    p_mean = [(g + h) / 2 for g, h in zip(p, p2)]

    ranked = sorted(zip(p_mean, range(len(p_mean))), reverse=True)[:next_words]
    return [(dataset.index_to_word[index], prob) for prob, index in ranked]
|
|
|
|
|
|
|
|
|
|
|
|
# In[ ]:
|
|
|
|
|
|
|
|
|
|
|
|
# Ad-hoc check of the trained forward model on a short prompt. The result is
# not stored or printed, so it is only visible when run as a notebook cell.
# predict() looks each word up in dataset.word_to_index, so this raises
# KeyError if any of the three words is not in the vocabulary.
predict(dataset, model, 'it is a')
|
|
|
|
|
|
|
|
|
|
|
|
# In[69]:
|
|
|
|
|
|
|
|
|
|
|
|
# Load the dev-0 split. The archive is read through a context manager so the
# handle is closed.
with lzma.open('dev-0/in.tsv.xz') as dev_file:
    dev_data = dev_file.read().decode('UTF-8').split('\n')
# Drop only the empty string left by a trailing newline, never a real row.
if dev_data[-1] == '':
    dev_data.pop()
dev_data = [line.split('\t') for line in dev_data]
# Columns 6 and 7 hold the text before and after the gap; both get the same
# normalisation as the training text (lower-case, a-z and spaces only).
dev_data1 = [re.sub('[^a-z ]', '', i[6].replace('\\n', ' ').lower()).strip() for i in dev_data]
dev_data2 = [re.sub('[^a-z ]', '', i[7].replace('\\n', ' ').lower()).strip() for i in dev_data]

# In[23]:

dev_data[0]

# In[54]:

# Predict from the last word of row 9's left context. The cleaned string in
# dev_data1 is used: dev_data[9] is a list of columns and has no .split().
print(predict(dataset, model, ' '.join(dev_data1[9].split()[-1:])))
|
|
|
|
|
|
|
|
|
|
|
|
# In[66]:
|
|
|
|
|
|
|
|
|
|
|
|
class ReversedDataset(torch.utils.data.Dataset):
    """Sliding-window dataset over the corpus read back to front.

    The corpus is read from ``<data_dir>/in.tsv.xz`` (tab-separated rows;
    columns 6 and 7 hold the text before and after the gap) and
    ``<data_dir>/expected.tsv`` (first column holds the word that fills the
    gap). Each expected word is spliced between its two contexts, the result
    is lower-cased, every character other than ``a-z`` and space is removed,
    the text is split on single spaces, and the token list is reversed.

    Because the split is on a single space, runs of spaces produce
    empty-string tokens; these are kept and are part of the vocabulary.

    Args:
        sequence_length: Number of tokens in each input/target window.
        max_rows: Maximum number of corpus rows to use (``None`` for all).
            Fewer rows are used when the corpus is shorter.
        data_dir: Directory containing ``in.tsv.xz`` and ``expected.tsv``.
    """

    def __init__(
        self,
        sequence_length,
        max_rows=5000,
        data_dir='train',
    ):
        self.sequence_length = sequence_length
        self.max_rows = max_rows
        self.data_dir = data_dir
        self.words = self.load()
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load(self):
        """Read the corpus and return its tokens in reverse order.

        Raises:
            IndexError: If a used input row has fewer than eight columns.
        """
        # Read through a context manager so the archive handle is closed.
        with lzma.open(f'{self.data_dir}/in.tsv.xz') as archive:
            raw = archive.read().decode('UTF-8')

        rows = raw.split('\n')
        # Only discard the final element when it is the empty string left by
        # a trailing newline; a last row without a newline is kept.
        if rows[-1] == '':
            rows.pop()
        rows = rows[:self.max_rows]

        # Doubly-escaped markers (two backslashes + "n") are blanked here;
        # the singly-escaped form is handled after the gap word is spliced in.
        contexts = [
            (columns[6].replace('\\\\n', ' '), columns[7].replace('\\\\n', ' '))
            for columns in (row.split('\t') for row in rows)
        ]

        with open(f'{self.data_dir}/expected.tsv', encoding='utf-8', newline='') as file:
            # A blank line becomes an empty word so rows stay aligned.
            words = [line[0] if line else '' for line in csv.reader(file, delimiter="\t")]
        words = words[:self.max_rows]

        # zip() stops at the shorter of the two inputs, so a corpus with
        # fewer than max_rows rows no longer raises IndexError.
        text = []
        for (left, right), word in zip(contexts, words):
            t = left + ' ' + word + ' ' + right + ' '
            text.append(t.replace('\\n', ' '))

        text = ' '.join(text).lower()
        text = re.sub('[^a-z ]', '', text)
        text = text.split(' ')
        text.reverse()
        return text

    def get_uniq_words(self):
        """Return the distinct tokens, most frequent first.

        Counting walks the tokens in forward (un-reversed) order, so tokens
        with equal counts are ordered by their first occurrence in the
        original text. Counting the reversed list directly would order ties
        by last occurrence and give a different word-to-index mapping from a
        forward dataset built on the same corpus.
        """
        word_counts = Counter(reversed(self.words))
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        """Return the number of windows; never negative."""
        return max(0, len(self.words_indexes) - self.sequence_length)

    def __getitem__(self, index):
        """Return ``(inputs, targets)``; targets are the inputs shifted by one token."""
        return (
            torch.tensor(self.words_indexes[index:index+self.sequence_length]),
            torch.tensor(self.words_indexes[index+1:index+self.sequence_length+1]),
        )
|
|
|
|
|
|
|
|
|
|
|
|
# In[67]:
|
|
|
|
|
|
|
|
|
|
|
|
# Reversed-direction corpus and the model that reads it.
dataset_2 = ReversedDataset(5)
input_tensor_2 = torch.tensor([[ 0, 231, 19, 98, 189]], dtype=torch.int32).to(device)

# Size the model to the vocabulary: len(dataset_2) is the number of training
# windows, not the number of distinct words.
model_2 = Model(vocab_size = len(dataset_2.uniq_words)).to(device)
# Smoke-test the new, still untrained model. This previously called the
# forward-direction `model` by mistake.
y_pred_2, state_h_2 = model_2(input_tensor_2)
train(dataset_2, model_2, 1, 64)
|
|
|
|
|
|
|
|
|
|
|
|
# In[96]:
|
|
|
|
|
|
|
|
|
|
|
|
# Number of context words taken from each side of the gap.
n = 2

# Write one prediction line per dev row. The context manager closes the file
# even if the loop is interrupted.
with open("dev-0/out.tsv", "w") as f:
    for d1, d2 in zip(dev_data1, dev_data2):
        try:
            tmp = predict2(dataset, model, model_2, ' '.join(d1.split()[-n:]), ' '.join(d2.split()[:n]))
            line = ' '.join([f'{word}:{prob}' for word, prob in tmp]) + ' :0.3\n'
        except Exception:
            # Best-effort fallback (e.g. KeyError for an out-of-vocabulary
            # context word): keep the row count aligned with the input.
            # Unlike a bare `except:`, this lets KeyboardInterrupt and
            # SystemExit propagate.
            line = ':1\n'
        f.write(line)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# In[95]:
|
|
|
|
|
|
|
|
|
|
|
|
# Number of dev rows. As a bare expression the value is discarded when this
# file is run as a plain script.
len(dev_data1)
|
|
|
|
|
|
|
|
|
|
|
|
# In[93]:
|
|
|
|
|
|
|
|
|
|
|
|
# Load the test-A split. The archive is read through a context manager so the
# handle is closed.
with lzma.open('test-A/in.tsv.xz') as test_file:
    test_data = test_file.read().decode('UTF-8').split('\n')
# Drop only the empty string left by a trailing newline, never a real row.
if test_data[-1] == '':
    test_data.pop()
test_data = [line.split('\t') for line in test_data]
# Columns 6 and 7 hold the text before and after the gap; both get the same
# normalisation as the training text (lower-case, a-z and spaces only).
test_data1 = [re.sub('[^a-z ]', '', i[6].replace('\\n', ' ').lower()).strip() for i in test_data]
test_data2 = [re.sub('[^a-z ]', '', i[7].replace('\\n', ' ').lower()).strip() for i in test_data]

# Number of context words taken from each side of the gap.
n = 2

# Write one prediction line per test row. The context manager closes the file
# even if the loop is interrupted.
with open("test-A/out.tsv", "w") as f:
    for d1, d2 in zip(test_data1, test_data2):
        try:
            tmp = predict2(dataset, model, model_2, ' '.join(d1.split()[-n:]), ' '.join(d2.split()[:n]))
            line = ' '.join([f'{word}:{prob}' for word, prob in tmp]) + ' :0.3\n'
        except Exception:
            # Best-effort fallback (e.g. KeyError for an out-of-vocabulary
            # context word): keep the row count aligned with the input.
            # Unlike a bare `except:`, this lets KeyboardInterrupt and
            # SystemExit propagate.
            line = ':1\n'
        f.write(line)
|
|
|
|
|
2022-03-31 14:53:43 +02:00
|
|
|
|