Compare commits

..

4 Commits
master ... 019

Author SHA1 Message Date
SzamanFL
ee23fd9d0f Fix in train 2021-01-08 20:17:07 +01:00
SzamanFL
bdc1e902e8 Fix a little 2021-01-08 18:12:58 +01:00
SzamanFL
0db9211817 Added model and train script 2021-01-07 23:29:15 +01:00
SzamanFL
d7c0e53b89 Update gitignore 2021-01-07 23:28:53 +01:00
3 changed files with 238 additions and 0 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
*~
*.swp
*.o
venv/

40
src/Model.py Normal file
View File

@ -0,0 +1,40 @@
import torch
class NgramModel(torch.nn.Module):
def __init__(self, vocab_size, n_hidden=256, n_layers=3, drop_prob=0.3, lr=0.001):
super().__init__()
self.drop_prob = drop_prob
self.n_hidden = n_hidden
self.n_layers = n_layers
self.lr = lr
self.vocab_size = vocab_size
self.embeddings = torch.nn.Embedding(self.vocab_size, 200)
self.rnn = torch.nn.RNN(200, self.n_hidden, self.n_layers, dropout = self.drop_prob, batch_first=True)
self.dropout = torch.nn.Dropout(self.drop_prob)
self.lin = torch.nn.Linear(self.n_hidden, self.vocab_size)
def forward(self, x, hidden):
embedded = self.embeddings(x)
output, hidden = self.rnn(embedded, hidden)
out = self.dropout(output)
out = out.reshape(-1, self.n_hidden)
out = self.lin(out)
return out, hidden
def init_hidden(self, batch_size):
weight = next(self.parameters()).data
if torch.cuda.is_available():
hidden = weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda()
else:
hidden = weight.new(self.n_layers, batch_size, self.n_hidden).zero_()
return hidden

197
src/train.py Normal file
View File

@ -0,0 +1,197 @@
#!/usr/bin/env python
print("Imports")
import argparse
import re
import pickle
import random
import numpy as np
import torch
from Model import NgramModel
def clear_data(string):
return re.sub("[^a-z' ]", "", string)
def read_clear_data(in_file):
print("Reading data")
texts = []
with open(in_file) as f:
for line in f:
start_period, end_period, title, symbol, text = line.rstrip('\n').split('\t')
texts.append(text)
print("Data read")
return texts
def create_ngrams(string, ngram_len=2):
n_grams = []
if len(string.split()) > ngram_len:
for i in range(ngram_len, len(string.split())):
n_gram = string.split()[i-ngram_len:i+1]
n_grams.append(" ".join(n_gram))
return n_grams
return [string]
def get_ngrams(data, ngram_len=2):
print("Creating ngrams")
n_grams = []
counter = 0
for string in data:
n_grams.append(create_ngrams(string))
counter += 1
percentage = round((counter/len(data))*100, 2)
print(f"Status: {percentage}%", end='\r')
print("Creating one list")
n_grams = sum(n_grams, [])
print("Created ngrams")
return n_grams
def segment_data(n_grams):
print("Segmenting data")
source = []
target = []
for string in n_grams:
# tutaj brac pod uwage jescze follow slowa
source.append(" ".join(string.split()[:-1]))
target.append(" ".join(string.split()[1:]))
print("Data segmented")
return source, target
def create_vocab(data):
print("Creating vocab")
vocab = {}
counter = 0
for word in set(" ".join(data).split()):
vocab[counter] = word
counter += 1
percentage = round((counter/len(data))*100, 2)
print(f"Status: {percentage}%", end='\r')
vocab = {t:i for i,t in vocab.items()}
print("Vocab created")
return vocab
def segment_with_vocab(vocab, target, source):
print("Segmenting...")
def get_int_seq(seq):
return [vocab[word] for word in seq.split()]
source_int = [get_int_seq(i) for i in source]
target_int = [get_int_seq(i) for i in target]
source_int = np.array(source_int)
target_int = np.array(target_int)
print("Segmented")
return source_int, target_int
def get_batches(source_arr, target_arr, batch_size):
counter = 0
for n in range(batch_size, source_arr.shape[0], batch_size):
x = source_arr[counter:n,:]
y = target_arr[counter:n,:]
counter = n
yield x, y
def train(net, source_int, target_int, seed, epochs=5, batch_size=32, lr=0.001, clip=1, step=30):
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
criterion = torch.nn.CrossEntropyLoss()
counter = 0
print("Start training")
torch.autograd.set_detect_anomaly(True)
net.train()
for epoch in range(epochs):
hidden = net.init_hidden(batch_size)
#import ipdb;ipdb.set_trace()
for x,y in get_batches(source_int, target_int, batch_size):
counter +=1
source, target = torch.from_numpy(x), torch.from_numpy(y)
if torch.cuda.is_available():
source = source.cuda()
target = target.cuda()
#hidden = tuple([each.data for each in hidden])
net.zero_grad()
output, hidden = net(source, hidden)
hidden.detach_()
loss = criterion(output, target.view(-1))
#if counter == 1:
# loss.backward(retain_graph=True)
#else:
# loss.backward()
loss.backward()
torch.nn.utils.clip_grad_norm_(net.parameters(), clip)
optimizer.step()
if counter % step == 0:
print(f"Epoch: {epoch}/{epochs} ; Step : {counter} ; loss : {loss}")
if counter % 500 == 0:
torch.save(net.state_dict(), f"checkpoint.ckpt-{counter}-epoch_{epoch}-seed_{seed}")
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--in_file')
parser.add_argument('--ngram_level', default=2, help="Level of ngram")
parser.add_argument('--ngrams', help="Path to pickle with ready bigrams")
parser.add_argument('--vocab')
parser.add_argument('--model')
args = parser.parse_args()
seed = random.randint(0, 20)
if args.ngrams:
print("Reading ngrams")
with open(args.ngrams, 'rb') as f:
source, target, data, n_grams = pickle.load(f)
print("Ngrams read")
else:
data = read_clear_data(args.in_file)
n_grams = get_ngrams(data, args.ngram_level)
source, target = segment_data(n_grams)
print("Saving progress...")
with open(f"n_grams-ngram_{args.ngram_level}-seed_{seed}.pickle", 'wb+') as f:
pickle.dump((source, target, data, n_grams), f)
print(f"Saved: n_grams-ngram_{args.ngram_level}-seed_{seed}.pickle")
if args.vocab:
print("Reading vocab")
with open(args.vocab, 'rb') as f:
vocab, source_int, target_int = pickle.load(f)
print("Vocab read")
else:
vocab = create_vocab(data)
print(f"Vocab size: {len(vocab)}")
source_int, target_int = segment_with_vocab(vocab, target, source)
print("Saving progress")
with open(f"vocab-seed_{seed}.pickle", 'wb+') as f:
pickle.dump((vocab, source_int, target_int), f)
print(f"Saved: vocab-seed_{seed}.pickle")
vocab_size = len(vocab)
net = NgramModel(vocab_size=vocab_size)
if args.model:
net.load_state_dict(torch.load(args.model))
if torch.cuda.is_available():
net.cuda()
train(net, source_int, target_int, seed)
main()