Compare commits
No commits in common. "019" and "master" have entirely different histories.
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,4 +1,3 @@
|
|||||||
*~
|
*~
|
||||||
*.swp
|
*.swp
|
||||||
*.o
|
*.o
|
||||||
venv/
|
|
||||||
|
40
src/Model.py
40
src/Model.py
@ -1,40 +0,0 @@
|
|||||||
import torch
|
|
||||||
|
|
||||||
class NgramModel(torch.nn.Module):
|
|
||||||
def __init__(self, vocab_size, n_hidden=256, n_layers=3, drop_prob=0.3, lr=0.001):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
self.drop_prob = drop_prob
|
|
||||||
self.n_hidden = n_hidden
|
|
||||||
self.n_layers = n_layers
|
|
||||||
self.lr = lr
|
|
||||||
self.vocab_size = vocab_size
|
|
||||||
|
|
||||||
self.embeddings = torch.nn.Embedding(self.vocab_size, 200)
|
|
||||||
|
|
||||||
self.rnn = torch.nn.RNN(200, self.n_hidden, self.n_layers, dropout = self.drop_prob, batch_first=True)
|
|
||||||
|
|
||||||
self.dropout = torch.nn.Dropout(self.drop_prob)
|
|
||||||
|
|
||||||
self.lin = torch.nn.Linear(self.n_hidden, self.vocab_size)
|
|
||||||
|
|
||||||
def forward(self, x, hidden):
|
|
||||||
embedded = self.embeddings(x)
|
|
||||||
|
|
||||||
output, hidden = self.rnn(embedded, hidden)
|
|
||||||
|
|
||||||
out = self.dropout(output)
|
|
||||||
out = out.reshape(-1, self.n_hidden)
|
|
||||||
|
|
||||||
out = self.lin(out)
|
|
||||||
return out, hidden
|
|
||||||
|
|
||||||
def init_hidden(self, batch_size):
|
|
||||||
weight = next(self.parameters()).data
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
hidden = weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda()
|
|
||||||
else:
|
|
||||||
hidden = weight.new(self.n_layers, batch_size, self.n_hidden).zero_()
|
|
||||||
|
|
||||||
return hidden
|
|
197
src/train.py
197
src/train.py
@ -1,197 +0,0 @@
|
|||||||
#!/usr/bin/env python
|
|
||||||
|
|
||||||
print("Imports")
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import re
|
|
||||||
import pickle
|
|
||||||
import random
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
|
||||||
|
|
||||||
from Model import NgramModel
|
|
||||||
|
|
||||||
def clear_data(string):
|
|
||||||
return re.sub("[^a-z' ]", "", string)
|
|
||||||
|
|
||||||
def read_clear_data(in_file):
|
|
||||||
print("Reading data")
|
|
||||||
texts = []
|
|
||||||
with open(in_file) as f:
|
|
||||||
for line in f:
|
|
||||||
start_period, end_period, title, symbol, text = line.rstrip('\n').split('\t')
|
|
||||||
texts.append(text)
|
|
||||||
print("Data read")
|
|
||||||
return texts
|
|
||||||
|
|
||||||
def create_ngrams(string, ngram_len=2):
|
|
||||||
n_grams = []
|
|
||||||
if len(string.split()) > ngram_len:
|
|
||||||
for i in range(ngram_len, len(string.split())):
|
|
||||||
n_gram = string.split()[i-ngram_len:i+1]
|
|
||||||
n_grams.append(" ".join(n_gram))
|
|
||||||
return n_grams
|
|
||||||
return [string]
|
|
||||||
|
|
||||||
def get_ngrams(data, ngram_len=2):
|
|
||||||
print("Creating ngrams")
|
|
||||||
n_grams = []
|
|
||||||
counter = 0
|
|
||||||
for string in data:
|
|
||||||
n_grams.append(create_ngrams(string))
|
|
||||||
counter += 1
|
|
||||||
percentage = round((counter/len(data))*100, 2)
|
|
||||||
print(f"Status: {percentage}%", end='\r')
|
|
||||||
|
|
||||||
print("Creating one list")
|
|
||||||
n_grams = sum(n_grams, [])
|
|
||||||
print("Created ngrams")
|
|
||||||
return n_grams
|
|
||||||
|
|
||||||
def segment_data(n_grams):
|
|
||||||
print("Segmenting data")
|
|
||||||
source = []
|
|
||||||
target = []
|
|
||||||
|
|
||||||
for string in n_grams:
|
|
||||||
# tutaj brac pod uwage jescze follow slowa
|
|
||||||
source.append(" ".join(string.split()[:-1]))
|
|
||||||
target.append(" ".join(string.split()[1:]))
|
|
||||||
|
|
||||||
print("Data segmented")
|
|
||||||
return source, target
|
|
||||||
|
|
||||||
def create_vocab(data):
|
|
||||||
print("Creating vocab")
|
|
||||||
vocab = {}
|
|
||||||
counter = 0
|
|
||||||
|
|
||||||
for word in set(" ".join(data).split()):
|
|
||||||
vocab[counter] = word
|
|
||||||
counter += 1
|
|
||||||
percentage = round((counter/len(data))*100, 2)
|
|
||||||
print(f"Status: {percentage}%", end='\r')
|
|
||||||
|
|
||||||
vocab = {t:i for i,t in vocab.items()}
|
|
||||||
print("Vocab created")
|
|
||||||
return vocab
|
|
||||||
|
|
||||||
def segment_with_vocab(vocab, target, source):
|
|
||||||
print("Segmenting...")
|
|
||||||
def get_int_seq(seq):
|
|
||||||
return [vocab[word] for word in seq.split()]
|
|
||||||
|
|
||||||
source_int = [get_int_seq(i) for i in source]
|
|
||||||
target_int = [get_int_seq(i) for i in target]
|
|
||||||
|
|
||||||
source_int = np.array(source_int)
|
|
||||||
target_int = np.array(target_int)
|
|
||||||
|
|
||||||
print("Segmented")
|
|
||||||
return source_int, target_int
|
|
||||||
|
|
||||||
def get_batches(source_arr, target_arr, batch_size):
|
|
||||||
counter = 0
|
|
||||||
for n in range(batch_size, source_arr.shape[0], batch_size):
|
|
||||||
x = source_arr[counter:n,:]
|
|
||||||
y = target_arr[counter:n,:]
|
|
||||||
counter = n
|
|
||||||
yield x, y
|
|
||||||
|
|
||||||
def train(net, source_int, target_int, seed, epochs=5, batch_size=32, lr=0.001, clip=1, step=30):
|
|
||||||
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
|
|
||||||
criterion = torch.nn.CrossEntropyLoss()
|
|
||||||
|
|
||||||
counter = 0
|
|
||||||
|
|
||||||
print("Start training")
|
|
||||||
torch.autograd.set_detect_anomaly(True)
|
|
||||||
net.train()
|
|
||||||
for epoch in range(epochs):
|
|
||||||
hidden = net.init_hidden(batch_size)
|
|
||||||
|
|
||||||
#import ipdb;ipdb.set_trace()
|
|
||||||
for x,y in get_batches(source_int, target_int, batch_size):
|
|
||||||
counter +=1
|
|
||||||
|
|
||||||
source, target = torch.from_numpy(x), torch.from_numpy(y)
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
source = source.cuda()
|
|
||||||
target = target.cuda()
|
|
||||||
|
|
||||||
#hidden = tuple([each.data for each in hidden])
|
|
||||||
|
|
||||||
net.zero_grad()
|
|
||||||
|
|
||||||
output, hidden = net(source, hidden)
|
|
||||||
hidden.detach_()
|
|
||||||
|
|
||||||
loss = criterion(output, target.view(-1))
|
|
||||||
|
|
||||||
#if counter == 1:
|
|
||||||
# loss.backward(retain_graph=True)
|
|
||||||
#else:
|
|
||||||
# loss.backward()
|
|
||||||
loss.backward()
|
|
||||||
|
|
||||||
torch.nn.utils.clip_grad_norm_(net.parameters(), clip)
|
|
||||||
optimizer.step()
|
|
||||||
|
|
||||||
if counter % step == 0:
|
|
||||||
print(f"Epoch: {epoch}/{epochs} ; Step : {counter} ; loss : {loss}")
|
|
||||||
|
|
||||||
if counter % 500 == 0:
|
|
||||||
torch.save(net.state_dict(), f"checkpoint.ckpt-{counter}-epoch_{epoch}-seed_{seed}")
|
|
||||||
|
|
||||||
def main():
|
|
||||||
parser = argparse.ArgumentParser()
|
|
||||||
parser.add_argument('--in_file')
|
|
||||||
parser.add_argument('--ngram_level', default=2, help="Level of ngram")
|
|
||||||
parser.add_argument('--ngrams', help="Path to pickle with ready bigrams")
|
|
||||||
parser.add_argument('--vocab')
|
|
||||||
parser.add_argument('--model')
|
|
||||||
args = parser.parse_args()
|
|
||||||
seed = random.randint(0, 20)
|
|
||||||
|
|
||||||
if args.ngrams:
|
|
||||||
print("Reading ngrams")
|
|
||||||
with open(args.ngrams, 'rb') as f:
|
|
||||||
source, target, data, n_grams = pickle.load(f)
|
|
||||||
print("Ngrams read")
|
|
||||||
else:
|
|
||||||
data = read_clear_data(args.in_file)
|
|
||||||
n_grams = get_ngrams(data, args.ngram_level)
|
|
||||||
source, target = segment_data(n_grams)
|
|
||||||
print("Saving progress...")
|
|
||||||
with open(f"n_grams-ngram_{args.ngram_level}-seed_{seed}.pickle", 'wb+') as f:
|
|
||||||
pickle.dump((source, target, data, n_grams), f)
|
|
||||||
print(f"Saved: n_grams-ngram_{args.ngram_level}-seed_{seed}.pickle")
|
|
||||||
|
|
||||||
if args.vocab:
|
|
||||||
print("Reading vocab")
|
|
||||||
with open(args.vocab, 'rb') as f:
|
|
||||||
vocab, source_int, target_int = pickle.load(f)
|
|
||||||
print("Vocab read")
|
|
||||||
else:
|
|
||||||
vocab = create_vocab(data)
|
|
||||||
print(f"Vocab size: {len(vocab)}")
|
|
||||||
source_int, target_int = segment_with_vocab(vocab, target, source)
|
|
||||||
print("Saving progress")
|
|
||||||
with open(f"vocab-seed_{seed}.pickle", 'wb+') as f:
|
|
||||||
pickle.dump((vocab, source_int, target_int), f)
|
|
||||||
print(f"Saved: vocab-seed_{seed}.pickle")
|
|
||||||
|
|
||||||
vocab_size = len(vocab)
|
|
||||||
|
|
||||||
net = NgramModel(vocab_size=vocab_size)
|
|
||||||
|
|
||||||
if args.model:
|
|
||||||
net.load_state_dict(torch.load(args.model))
|
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
net.cuda()
|
|
||||||
|
|
||||||
train(net, source_int, target_int, seed)
|
|
||||||
|
|
||||||
main()
|
|
Loading…
Reference in New Issue
Block a user