#!/usr/bin/env python
# coding: utf-8
# In[2]:
import torch
from torch import nn, optim
from torch.utils.data import DataLoader
import numpy as np
from collections import Counter
import re
import lzma
import csv
# In[3]:
# Fall back to CPU when no GPU is available.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# In[4]:
class Dataset(torch.utils.data.Dataset):
    def __init__(
        self,
        sequence_length,
    ):
        self.sequence_length = sequence_length
        self.words = self.load()
        self.uniq_words = self.get_uniq_words()

        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}

        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load(self):
        # Columns 6 and 7 hold the left and right context around the word gap.
        data = lzma.open('train/in.tsv.xz').read().decode('UTF-8').split('\n')
        data = [line.split('\t') for line in data][:-1]
        data = [[i[6].replace('\\\\n', ' '), i[7].replace('\\\\n', ' ')] for i in data]

        # expected.tsv holds the gap word for each training row.
        words = []
        with open('train/expected.tsv') as file:
            tsv_file = csv.reader(file, delimiter="\t")
            for line in tsv_file:
                words.append(line[0])

        # Stitch left context + gap word + right context back into running
        # text; only the first 5000 rows are used to keep training fast.
        text = []
        # for i in range(len(data) - 1):
        for i in range(5000):
            t = data[i][0] + ' ' + words[i] + ' ' + data[i][1] + ' '
            text += [t.replace('\\n', ' ')]

        text = ' '.join(text).lower()
        text = re.sub('[^a-z ]', '', text)
        text = text.split(' ')
        return text

    def get_uniq_words(self):
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        return len(self.words_indexes) - self.sequence_length

    def __getitem__(self, index):
        # Input window of sequence_length tokens, and the same window
        # shifted one token to the right as the target.
        return (
            torch.tensor(self.words_indexes[index:index+self.sequence_length]),
            torch.tensor(self.words_indexes[index+1:index+self.sequence_length+1]),
        )
# In[5]:
dataset = Dataset(5)
# In[6]:
dataset[200]
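# Sanity check (illustrative, not part of the original pipeline): the target
# window is the input window shifted one token to the right.
_x, _y = dataset[200]
assert torch.equal(_x[1:], _y[:-1])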
# In[7]:
[dataset.index_to_word[x] for x in [ 0, 231, 19, 98, 189]]
# In[8]:
[dataset.index_to_word[x] for x in [231, 19, 98, 189, 5]]
# In[9]:
input_tensor = torch.tensor([[ 0, 231, 19, 98, 189]], dtype=torch.int32).to(device)
# In[ ]:
class Model(nn.Module):
    def __init__(self, vocab_size):
        super(Model, self).__init__()
        self.lstm_size = 128
        self.embedding_dim = 128
        self.num_layers = 3

        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=self.embedding_dim,
        )
        # batch_first=True so the LSTM consumes the (batch, seq, features)
        # tensors produced by embedding the (batch, seq) index batches.
        self.lstm = nn.LSTM(
            input_size=self.embedding_dim,
            hidden_size=self.lstm_size,
            num_layers=self.num_layers,
            dropout=0.2,
            batch_first=True,
        )
        self.fc = nn.Linear(self.lstm_size, vocab_size)

    def forward(self, x, prev_state=None):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, batch_size):
        # Hidden and cell states have shape (num_layers, batch, hidden_size).
        return (torch.zeros(self.num_layers, batch_size, self.lstm_size).to(device),
                torch.zeros(self.num_layers, batch_size, self.lstm_size).to(device))
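# Illustrative shape check (not in the original): a (1, 5) batch of token
# indices should map to logits of shape (1, 5, vocab_size).
_m = Model(vocab_size=10).to(device)
_logits, _ = _m(torch.tensor([[0, 1, 2, 3, 4]]).to(device))
assert _logits.shape == (1, 5, 10)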
# In[ ]:
model = Model(len(dataset.uniq_words)).to(device)  # output layer sized to the vocabulary
# In[ ]:
y_pred, state_h = model(input_tensor)
# In[ ]:
y_pred
# In[ ]:
y_pred.shape
# In[ ]:
def train(dataset, model, max_epochs, batch_size):
    model.train()

    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    for epoch in range(max_epochs):
        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            x = x.to(device)
            y = y.to(device)

            y_pred, state_h = model(x)
            # CrossEntropyLoss expects (batch, classes, seq), so swap the
            # sequence and vocabulary dimensions of the logits.
            loss = criterion(y_pred.transpose(1, 2), y)

            loss.backward()
            optimizer.step()

            print({ 'epoch': epoch, 'update in batch': batch, '/' : len(dataloader), 'loss': loss.item() })
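# Rough yardstick (illustrative, not in the original): an untrained model
# should start near the uniform-distribution cross-entropy, ln(vocab_size).
import math
print('uniform-baseline loss:', math.log(len(dataset.uniq_words)))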
# In[ ]:
model = Model(vocab_size=len(dataset.uniq_words)).to(device)  # fresh weights for training
train(dataset, model, 1, 64)
# In[ ]:
def predict(dataset, model, text, next_words=5):
    model.eval()

    words = text.split(' ')
    state_h = model.init_state(1)  # a single sequence, so batch size 1

    res = []
    x = torch.tensor([[dataset.word_to_index[w] for w in words]]).to(device)
    y_pred, state_h = model(x, state_h)

    # Softmax over the logits of the last position, then keep the
    # next_words most probable candidates as (word, probability) pairs.
    last_word_logits = y_pred[0][-1]
    p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()
    tmp = sorted(zip(p, range(len(p))), reverse=True)[:next_words]
    for w in tmp:
        res.append((dataset.index_to_word[w[1]], w[0]))
    return res
def predict2(dataset, model, model2, text, text2, next_words=5):
    # Ensemble of the forward model (fed the left context) and the backward
    # model (fed the reversed right context); their next-word distributions
    # are averaged. Note that both inputs are indexed with the forward
    # dataset's vocabulary, which assumes the two datasets assign identical
    # indices to each word.
    model.eval()
    model2.eval()

    words = text.split(' ')
    words2 = text2.split(' ')
    words2.reverse()

    state_h = model.init_state(1)
    state_h_2 = model2.init_state(1)

    res = []
    x = torch.tensor([[dataset.word_to_index[w] for w in words]]).to(device)
    x2 = torch.tensor([[dataset.word_to_index[w] for w in words2]]).to(device)

    y_pred, state_h = model(x, state_h)
    y_pred_2, state_h_2 = model2(x2, state_h_2)

    last_word_logits = y_pred[0][-1]
    last_word_logits_2 = y_pred_2[0][-1]
    p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()
    p2 = torch.nn.functional.softmax(last_word_logits_2, dim=0).detach().cpu().numpy()
    p_mean = [(g + h) / 2 for g, h in zip(p, p2)]

    tmp = sorted(zip(p_mean, range(len(p_mean))), reverse=True)[:next_words]
    for w in tmp:
        res.append((dataset.index_to_word[w[1]], w[0]))
    return res
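# Equivalent vectorized averaging (illustrative): the softmax outputs are
# numpy arrays, so the per-element loop above could be array arithmetic.
_pa = np.array([0.2, 0.8])
_pb = np.array([0.6, 0.4])
assert np.allclose((_pa + _pb) / 2, [(g + h) / 2 for g, h in zip(_pa, _pb)])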
# In[ ]:
predict(dataset, model, 'it is a')
# In[69]:
dev_data = lzma.open('dev-0/in.tsv.xz').read().decode('UTF-8').split('\n')
dev_data = [line.split('\t') for line in dev_data][:-1]
dev_data1 = [re.sub('[^a-z ]', '', i[6].replace('\\n', ' ').lower()).strip() for i in dev_data]
dev_data2 = [re.sub('[^a-z ]', '', i[7].replace('\\n', ' ').lower()).strip() for i in dev_data]
# In[23]:
dev_data[0]
# In[54]:
# dev_data[9] is a list of fields; use the cleaned text in dev_data1 instead.
print(predict(dataset, model, ' '.join(dev_data1[9].split()[-1:])))
# In[66]:
class ReversedDataset(Dataset):
    # Same corpus and windowing as Dataset, but with the token sequence
    # reversed, so a model trained on it predicts words right-to-left.
    def load(self):
        text = super().load()
        text.reverse()
        return text
# In[67]:
dataset_2 = ReversedDataset(5)
input_tensor_2 = torch.tensor([[ 0, 231, 19, 98, 189]], dtype=torch.int32).to(device)

model_2 = Model(vocab_size=len(dataset_2.uniq_words)).to(device)
y_pred_2, state_h_2 = model_2(input_tensor_2)  # smoke test of the backward model
train(dataset_2, model_2, 1, 64)
# In[96]:
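# Each output line lists candidate words with their probabilities plus a
# catch-all mass of 0.3 for all remaining words, e.g. (illustrative values):
#   the:0.21 a:0.12 his:0.08 this:0.05 it:0.03 :0.3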
n = 2
with open("dev-0/out.tsv", "w") as f:
    for i in range(len(dev_data1)):
        d1 = dev_data1[i]
        d2 = dev_data2[i]
        try:
            tmp = predict2(dataset, model, model_2, ' '.join(d1.split()[-n:]), ' '.join(d2.split()[:n]))
            f.write(' '.join([f'{w}:{p}' for w, p in tmp]) + ' :0.3\n')
        except KeyError:
            # A context word missing from the training vocabulary: fall back
            # to an uninformative prediction.
            f.write(':1\n')
# In[95]:
len(dev_data1)
# In[93]:
test_data = lzma.open('test-A/in.tsv.xz').read().decode('UTF-8').split('\n')
test_data = [line.split('\t') for line in test_data][:-1]
test_data1 = [re.sub('[^a-z ]', '', i[6].replace('\\n', ' ').lower()).strip() for i in test_data]
test_data2 = [re.sub('[^a-z ]', '', i[7].replace('\\n', ' ').lower()).strip() for i in test_data]

n = 2
with open("test-A/out.tsv", "w") as f:
    for i in range(len(test_data1)):
        d1 = test_data1[i]
        d2 = test_data2[i]
        try:
            tmp = predict2(dataset, model, model_2, ' '.join(d1.split()[-n:]), ' '.join(d2.split()[:n]))
            f.write(' '.join([f'{w}:{p}' for w, p in tmp]) + ' :0.3\n')
        except KeyError:
            f.write(':1\n')
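# Optional sketch (not in the original): persist both trained models so the
# prediction step can be re-run without retraining. File names are made up.
torch.save(model.state_dict(), 'model_forward.pt')
torch.save(model_2.state_dict(), 'model_backward.pt')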