Compare commits

4 Commits

| Author | SHA1 | Date |
|---|---|---|
| | ee23fd9d0f | |
| | bdc1e902e8 | |
| | 0db9211817 | |
| | d7c0e53b89 | |
.gitignore (vendored, +1 line)

```diff
@@ -1,3 +1,4 @@
 *~
 *.swp
 *.o
+venv/
```
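Not part of the diff: a quick way to confirm the new rule matches, assuming a virtualenv directory named `venv` at the repo root.

```sh
git check-ignore -v venv/
# expected output: .gitignore:4:venv/	venv/
```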
src/Model.py (new file, 40 lines)

```python
import torch


class NgramModel(torch.nn.Module):
    def __init__(self, vocab_size, n_hidden=256, n_layers=3, drop_prob=0.3, lr=0.001):
        super().__init__()

        self.drop_prob = drop_prob
        self.n_hidden = n_hidden
        self.n_layers = n_layers
        self.lr = lr
        self.vocab_size = vocab_size

        self.embeddings = torch.nn.Embedding(self.vocab_size, 200)
        self.rnn = torch.nn.RNN(200, self.n_hidden, self.n_layers, dropout=self.drop_prob, batch_first=True)
        self.dropout = torch.nn.Dropout(self.drop_prob)
        self.lin = torch.nn.Linear(self.n_hidden, self.vocab_size)

    def forward(self, x, hidden):
        embedded = self.embeddings(x)
        output, hidden = self.rnn(embedded, hidden)

        out = self.dropout(output)
        out = out.reshape(-1, self.n_hidden)
        out = self.lin(out)
        return out, hidden

    def init_hidden(self, batch_size):
        weight = next(self.parameters()).data

        if torch.cuda.is_available():
            hidden = weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda()
        else:
            hidden = weight.new(self.n_layers, batch_size, self.n_hidden).zero_()

        return hidden
```
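Not part of the commit: a minimal smoke test for the new module, assuming a hypothetical vocabulary of 1,000 tokens and bigram-length inputs (both values illustrative). Run it from `src/` so `Model` is importable.

```python
import torch
from Model import NgramModel

net = NgramModel(vocab_size=1000)   # hypothetical vocab size
net.eval()                          # disable dropout for a deterministic check

x = torch.randint(0, 1000, (4, 2))  # (batch=4, seq_len=2) integer token ids
hidden = net.init_hidden(batch_size=4)
out, hidden = net(x, hidden)

print(out.shape)     # torch.Size([8, 1000]): (batch * seq_len, vocab_size)
print(hidden.shape)  # torch.Size([3, 4, 256]): (n_layers, batch, n_hidden)
```

The flattened (batch * seq_len, vocab_size) output is what lets train.py feed the logits straight into CrossEntropyLoss against a flattened target.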
src/train.py (new file, 197 lines)
```python
#!/usr/bin/env python

print("Imports")

import argparse
import re
import pickle
import random
import numpy as np
import torch

from Model import NgramModel


def clear_data(string):
    return re.sub("[^a-z' ]", "", string)


def read_clear_data(in_file):
    print("Reading data")
    texts = []
    with open(in_file) as f:
        for line in f:
            start_period, end_period, title, symbol, text = line.rstrip('\n').split('\t')
            texts.append(text)
    print("Data read")
    return texts


def create_ngrams(string, ngram_len=2):
    n_grams = []
    if len(string.split()) > ngram_len:
        for i in range(ngram_len, len(string.split())):
            n_gram = string.split()[i-ngram_len:i+1]
            n_grams.append(" ".join(n_gram))
        return n_grams
    return [string]


def get_ngrams(data, ngram_len=2):
    print("Creating ngrams")
    n_grams = []
    counter = 0
    for string in data:
        n_grams.append(create_ngrams(string, ngram_len))
        counter += 1
        percentage = round((counter/len(data))*100, 2)
        print(f"Status: {percentage}%", end='\r')

    print("Creating one list")
    n_grams = sum(n_grams, [])
    print("Created ngrams")
    return n_grams


def segment_data(n_grams):
    print("Segmenting data")
    source = []
    target = []

    for string in n_grams:
        # take the follow words into account here as well
        source.append(" ".join(string.split()[:-1]))
        target.append(" ".join(string.split()[1:]))

    print("Data segmented")
    return source, target


def create_vocab(data):
    print("Creating vocab")
    vocab = {}
    counter = 0

    for word in set(" ".join(data).split()):
        vocab[counter] = word
        counter += 1
        percentage = round((counter/len(data))*100, 2)
        print(f"Status: {percentage}%", end='\r')

    vocab = {t: i for i, t in vocab.items()}
    print("Vocab created")
    return vocab


def segment_with_vocab(vocab, target, source):
    print("Segmenting...")

    def get_int_seq(seq):
        return [vocab[word] for word in seq.split()]

    source_int = [get_int_seq(i) for i in source]
    target_int = [get_int_seq(i) for i in target]

    source_int = np.array(source_int)
    target_int = np.array(target_int)

    print("Segmented")
    return source_int, target_int


def get_batches(source_arr, target_arr, batch_size):
    # yield consecutive (source, target) mini-batches; the final partial batch is dropped
    counter = 0
    for n in range(batch_size, source_arr.shape[0], batch_size):
        x = source_arr[counter:n, :]
        y = target_arr[counter:n, :]
        counter = n
        yield x, y


def train(net, source_int, target_int, seed, epochs=5, batch_size=32, lr=0.001, clip=1, step=30):
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
    criterion = torch.nn.CrossEntropyLoss()

    counter = 0

    print("Start training")
    torch.autograd.set_detect_anomaly(True)
    net.train()
    for epoch in range(epochs):
        hidden = net.init_hidden(batch_size)

        for x, y in get_batches(source_int, target_int, batch_size):
            counter += 1

            source, target = torch.from_numpy(x), torch.from_numpy(y)

            if torch.cuda.is_available():
                source = source.cuda()
                target = target.cuda()

            net.zero_grad()

            output, hidden = net(source, hidden)
            # detach so gradients do not flow across batch boundaries
            hidden.detach_()

            loss = criterion(output, target.view(-1))
            loss.backward()

            torch.nn.utils.clip_grad_norm_(net.parameters(), clip)
            optimizer.step()

            if counter % step == 0:
                print(f"Epoch: {epoch}/{epochs} ; Step : {counter} ; loss : {loss}")

            if counter % 500 == 0:
                torch.save(net.state_dict(), f"checkpoint.ckpt-{counter}-epoch_{epoch}-seed_{seed}")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--in_file')
    parser.add_argument('--ngram_level', type=int, default=2, help="Level of ngram")
    parser.add_argument('--ngrams', help="Path to pickle with ready bigrams")
    parser.add_argument('--vocab')
    parser.add_argument('--model')
    args = parser.parse_args()
    seed = random.randint(0, 20)

    if args.ngrams:
        print("Reading ngrams")
        with open(args.ngrams, 'rb') as f:
            source, target, data, n_grams = pickle.load(f)
        print("Ngrams read")
    else:
        data = read_clear_data(args.in_file)
        n_grams = get_ngrams(data, args.ngram_level)
        source, target = segment_data(n_grams)
        print("Saving progress...")
        with open(f"n_grams-ngram_{args.ngram_level}-seed_{seed}.pickle", 'wb+') as f:
            pickle.dump((source, target, data, n_grams), f)
        print(f"Saved: n_grams-ngram_{args.ngram_level}-seed_{seed}.pickle")

    if args.vocab:
        print("Reading vocab")
        with open(args.vocab, 'rb') as f:
            vocab, source_int, target_int = pickle.load(f)
        print("Vocab read")
    else:
        vocab = create_vocab(data)
        print(f"Vocab size: {len(vocab)}")
        source_int, target_int = segment_with_vocab(vocab, target, source)
        print("Saving progress")
        with open(f"vocab-seed_{seed}.pickle", 'wb+') as f:
            pickle.dump((vocab, source_int, target_int), f)
        print(f"Saved: vocab-seed_{seed}.pickle")

    vocab_size = len(vocab)

    net = NgramModel(vocab_size=vocab_size)

    if args.model:
        net.load_state_dict(torch.load(args.model))

    if torch.cuda.is_available():
        net.cuda()

    train(net, source_int, target_int, seed)


main()
```
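Not part of the commit: a sketch of how src/train.py might be driven, using only the arguments its main() defines. File names are illustrative; the pickle and checkpoint names follow the patterns the script itself prints, with the seed drawn randomly per run.

```sh
# Fresh run: read a tab-separated corpus, build n-grams and the vocab, then train.
python src/train.py --in_file corpus.tsv --ngram_level 2

# Later run: reuse the saved artifacts and warm-start from a checkpoint
# (the seed suffix 7 is illustrative).
python src/train.py \
    --ngrams n_grams-ngram_2-seed_7.pickle \
    --vocab vocab-seed_7.pickle \
    --model checkpoint.ckpt-500-epoch_0-seed_7
```

Note that read_clear_data expects exactly five tab-separated fields per input line and keeps only the last one (the text).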