import csv
from collections import Counter

import numpy as np
import pandas as pd
import regex as re
import torch
from torch import nn, optim
from torch.utils.data import DataLoader

torch.cuda.empty_cache()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
class Dataset(torch.utils.data.Dataset):
    def __init__(self, sequence_length, file_path):
        self.file_path = file_path
        self.sequence_length = sequence_length
        self.words = self.load()
        self.uniq_words = self.get_uniq_words()
        self.index_to_word = {index: word for index, word in enumerate(self.uniq_words)}
        self.word_to_index = {word: index for index, word in enumerate(self.uniq_words)}
        self.words_indexes = [self.word_to_index[w] for w in self.words]

    def load(self):
        # Read the corpus, lowercase it and keep only Polish letters and spaces.
        with open(self.file_path, 'r') as f_in:
            text = [x.rstrip() for x in f_in.readlines() if x.strip()]
        text = ' '.join(text).lower()
        text = re.sub('[^a-ząćęłńóśźż ]', '', text)
        text = text.split(' ')
        return text

    def get_uniq_words(self):
        word_counts = Counter(self.words)
        return sorted(word_counts, key=word_counts.get, reverse=True)

    def __len__(self):
        return len(self.words_indexes) - self.sequence_length

    def __getitem__(self, index):
        # Input is a window of sequence_length words; the target is the same window shifted by one word.
        return (
            torch.tensor(self.words_indexes[index:index + self.sequence_length]),
            torch.tensor(self.words_indexes[index + 1:index + self.sequence_length + 1]),
        )
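# Quick illustration of the (input, target) pairs the Dataset yields
# (assumes 'processed_train.txt' has already been written further below):
# ds = Dataset(5, 'processed_train.txt')
# x, y = ds[0]
# x holds the indices of words 0..4, y the indices of words 1..5,
# i.e. the target sequence is the input shifted one word to the right.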
class Model(nn.Module):
    def __init__(self, vocab_size):
        super(Model, self).__init__()
        self.lstm_size = 128
        self.embedding_dim = 256
        self.num_layers = 3
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=self.embedding_dim)
        # The LSTM consumes the embedding output, so input_size must equal embedding_dim.
        self.lstm = nn.LSTM(input_size=self.embedding_dim, hidden_size=self.lstm_size,
                            num_layers=self.num_layers, dropout=0.2)
        self.fc = nn.Linear(self.lstm_size, vocab_size)

    def forward(self, x, prev_state=None):
        embed = self.embedding(x)
        output, state = self.lstm(embed, prev_state)
        logits = self.fc(output)
        return logits, state

    def init_state(self, sequence_length):
        zeros = torch.zeros(self.num_layers, sequence_length, self.lstm_size).to(device)
        return (zeros, zeros)
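# Rough shape trace through forward() (batch_first is left at its default, so the
# LSTM reads dimension 0 of the embedded input as the time axis):
# x: (d0, d1) word indices -> embedding: (d0, d1, 256) -> lstm output: (d0, d1, 128)
# -> fc logits: (d0, d1, vocab_size); the returned state is an (h, c) pair,
# each of shape (num_layers, d1, 128), which is what init_state() produces.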
def clean_text(text):
    # The TSV stores line breaks as the literal two-character sequence "\n";
    # drop hyphenated breaks and turn the rest into spaces.
    text = text.lower().replace('-\\n', '').replace('\\n', ' ')
    # Expand common English contractions before punctuation is stripped,
    # otherwise the apostrophes are already gone and nothing matches.
    text = text.replace("n't", " not").replace("'s", " is").replace("'ll", " will").replace("'m", " am").replace("'ve", " have")
    text = re.sub(r'\p{P}', '', text)
    return text
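# Example of what clean_text does to a raw row (illustrative string, not taken from the data):
# clean_text("He didn't go,\\nbut I'll stay") -> "he did not go but i will stay"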
# Columns 6 and 7 of in.tsv hold the left and right context; expected.tsv holds the gap word.
train_data = pd.read_csv('train/in.tsv.xz', sep='\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)
train_labels = pd.read_csv('train/expected.tsv', sep='\t', error_bad_lines=False, warn_bad_lines=False, header=None, quoting=csv.QUOTE_NONE)

train_data = train_data[[6, 7]]
train_data = pd.concat([train_data, train_labels], axis=1)
# Rebuild the full passage: left context + expected word + right context, separated by spaces.
train_data['text'] = train_data[6] + ' ' + train_data[0] + ' ' + train_data[7]
train_data = train_data[['text']]

with open('processed_train.txt', 'w', encoding='utf-8') as file:
    for _, row in train_data.iterrows():
        text = clean_text(str(row['text']))
        file.write(text + '\n')

data = Dataset(5, 'processed_train.txt')
model = Model(vocab_size=len(data.uniq_words)).to(device)
def train(dataset, model, max_epochs, batch_size):
    model.train()
    dataloader = DataLoader(dataset, batch_size=batch_size)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    for epoch in range(max_epochs):
        for batch, (x, y) in enumerate(dataloader):
            optimizer.zero_grad()
            x = x.to(device)
            y = y.to(device)
            y_pred, (state_h, state_c) = model(x)
            # CrossEntropyLoss expects (N, C, d), so move the vocabulary dimension to position 1.
            loss = criterion(y_pred.transpose(1, 2), y)
            loss.backward()
            optimizer.step()
            print({'epoch': epoch, 'batch': batch, 'batches': len(dataloader), 'loss': loss.item()})
def predict(dataset, model, text, next_words=5):
    model.eval()
    words = text.split(' ')
    state_h, state_c = model.init_state(len(words))
    for i in range(0, next_words):
        x = torch.tensor([[dataset.word_to_index[w] for w in words[i:]]]).to(device)
        y_pred, (state_h, state_c) = model(x, (state_h, state_c))
        last_word_logits = y_pred[0][-1]
        # Sample the next word from the softmax distribution over the vocabulary.
        p = torch.nn.functional.softmax(last_word_logits, dim=0).detach().cpu().numpy()
        word_index = np.random.choice(len(last_word_logits), p=p)
        words.append(dataset.index_to_word[word_index])
    return words
#train(data, model, 1, 128)
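# Hypothetical use once training has been run (the prompt is made up and each of its
# words must already be in the training vocabulary, otherwise word_to_index raises KeyError):
# print(' '.join(predict(data, model, 'to jest')))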