Added model and train script
This commit is contained in:
parent
d7c0e53b89
commit
0db9211817
40
src/Model.py
Normal file
40
src/Model.py
Normal file
@ -0,0 +1,40 @@
|
||||
import torch
|
||||
|
||||
class NgramModel(torch.nn.Module):
|
||||
def __init__(self, n_hidden=256, n_layers=3, drop_prob=0.3, lr=0.001, vocab_size):
|
||||
super().__init__()
|
||||
|
||||
self.drop_prob = drop_prob
|
||||
self.n_hidden = n_hidden
|
||||
self.n_layers = n_layers
|
||||
self.lr = lr
|
||||
self.vocab_size = vocab_size
|
||||
|
||||
self.embeddings = torch.nn.Embedding(self.vocab_size, 200)
|
||||
|
||||
self.rnn = torch.nn.RNN(200, self.n_hidden, self.n_layers, dropout = self.drop_prob, batch_first=True)
|
||||
|
||||
self.dropout = torch.nn.Dropout(self.drop_prob)
|
||||
|
||||
self.lin = torch.nn.Linear(self.n_hidden, self.vocab_size)
|
||||
|
||||
def forward(self, x, hidden):
|
||||
embedded = self.embeddings(x)
|
||||
|
||||
output, hidden = self.rnn(embedded, hidden)
|
||||
|
||||
out = self.dropout(output)
|
||||
out = out.reshape(-1, self.n_hidden)
|
||||
|
||||
out = self.lin(out)
|
||||
return out, hidden
|
||||
|
||||
def init_hidden(self, batch_size):
|
||||
weight = next(self.parameters()).data
|
||||
|
||||
if torch.cuda.is_available():
|
||||
hidden = (weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda(),weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda())
|
||||
else:
|
||||
hidden = (weight.new(self.n_layers, batch_size, self.n_hidden).zero_(),weight.new(self.n_layers, batch_size, self.n_hidden).zero_())
|
||||
|
||||
return hidden
|
180
src/train.py
Normal file
180
src/train.py
Normal file
@ -0,0 +1,180 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import argparse
|
||||
import re
|
||||
import pickle
|
||||
import random
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from Model import NgramModel
|
||||
|
||||
def clear_data(string):
|
||||
return re.sub("[^a-z' ]", "", string)
|
||||
|
||||
def read_clear_data(in_file):
|
||||
print("Reading data")
|
||||
texts = []
|
||||
with open(in_file) as f:
|
||||
for line in f:
|
||||
start_period, end_period, title, symbol, text = line.rstrip('\n').split('\t')
|
||||
texts.append(text)
|
||||
print("Data read")
|
||||
return texts
|
||||
|
||||
def create_ngrams(string, ngram_len=2):
|
||||
print("Creating ngrams")
|
||||
n_grams = []
|
||||
if len(string.split()) > ngram_len:
|
||||
for i in range(ngram_len, len(string.split())):
|
||||
n_gram = string.split()[i-ngram_len:i+1]
|
||||
n_grams.append(" ".join(n_gram))
|
||||
return n_grams
|
||||
return [string]
|
||||
|
||||
def get_ngrams(data, ngram_len=2):
|
||||
print("Creating ngrams")
|
||||
n_grams = []
|
||||
for string in data:
|
||||
n_grams.append(create_ngrams(sring))
|
||||
|
||||
n_grams = sum(n_grams, [])
|
||||
print("Created ngrams")
|
||||
return n_grams
|
||||
|
||||
def segment_data(n_grams):
|
||||
print("Segmenting data")
|
||||
source = []
|
||||
target = []
|
||||
|
||||
for string in n_grams:
|
||||
source.append(" ".join(string.split()[:-1]))
|
||||
target.append(" ".join(string.split()[1:]))
|
||||
|
||||
print("Data segmented")
|
||||
return source, target
|
||||
|
||||
def create_vocab(data):
|
||||
print("Creating vocab")
|
||||
vocab = {}
|
||||
counter = 0
|
||||
|
||||
for word in set(" ".join(data).split()):
|
||||
vocab[counter] = word
|
||||
counter += 1
|
||||
|
||||
vocab = {t:i for i,t in vocab.items()}
|
||||
print("Vocab created")
|
||||
return vocab
|
||||
|
||||
def segment_with_vocab(vocab, target, source):
|
||||
print("Segmenting...")
|
||||
def get_int_seq(seq):
|
||||
return [vocab[word] for word in seq.split()]
|
||||
|
||||
source_int = [get_int_seq(i) for i in source]
|
||||
target_int = [get_int_seq(i) for i in target]
|
||||
|
||||
source_int = np.array(source_int)
|
||||
target_int = np.array(target_int)
|
||||
|
||||
print("Segmented")
|
||||
return source_int, target_int
|
||||
|
||||
def get_batches(source_arr, target_arr, batch_size):
|
||||
counter = 0
|
||||
for n in range(batch_size, source_arr.shape[0], batch_size):
|
||||
x = source_arr[counter:n,:]
|
||||
y = target_arr[counter:n,:]
|
||||
counter = n
|
||||
yield x, y
|
||||
|
||||
def train(net, source_int, target_int, seed, epochs=5, batch_size=32, lr=0.001, clip=1, step=30):
|
||||
optimizer = torch.optim.Adam(net.parameters(), lr=lr)
|
||||
criterion = torch.nn.CrossEntropyLoss()
|
||||
|
||||
counter = 0
|
||||
|
||||
net.train()
|
||||
|
||||
for epoch in range(epochs):
|
||||
h = net.init_hidden(batch_size)
|
||||
|
||||
for x,y in get_batches(source_int, target_int, batch_size):
|
||||
counter +=1
|
||||
|
||||
source, target = torch.from_numpy(x), torch.from_numpy(y)
|
||||
|
||||
if torch.cuda.is_available():
|
||||
source = source.cuda()
|
||||
target = target.cuda()
|
||||
|
||||
h = tuple([each.data for each in h])
|
||||
|
||||
net.zero_grad()
|
||||
|
||||
output, h = net(source, h)
|
||||
|
||||
loss = criterion(output, target.view(-1))
|
||||
|
||||
loss.backward()
|
||||
|
||||
nn.utils.clip_grad_norm_(net.parameters(), clip)
|
||||
optimizer.step()
|
||||
|
||||
if counter % step == 0:
|
||||
print(f"Epoch: {epoch}/{epochs} ; Step : {counter}")
|
||||
|
||||
if counter % 500 == 0:
|
||||
torch.save(net.state_dict(), f"checkpoint.ckpt-{counter}-epoch_{epoch}-seed_{seed}")
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--in_file', '')
|
||||
parser.add_argument('--ngram_level', default=2, help="Level of ngram")
|
||||
parser.add_argument('--ngrams', help="Path to pickle with ready bigrams")
|
||||
parser.add_argument('--vocab')
|
||||
parser.add_argument('--model')
|
||||
args = parser.parse_args()
|
||||
seed = random.randint(0, 20)
|
||||
|
||||
if args.ngrams:
|
||||
print("Reading ngrams")
|
||||
with open(args.ngrams, 'rb') as f:
|
||||
source, target, data = pickle.load(f)
|
||||
print("Ngrams read")
|
||||
else:
|
||||
data = read_clear_data(args.in_file)
|
||||
n_grams = get_ngrams(data, args.ngram_level)
|
||||
source, target = segment_data(n_grams)
|
||||
print("Saving progress...")
|
||||
with open(f"n_grams-ngram_{ngram_level}-seed_{seed}.pickle", 'wb+') as f:
|
||||
pickle.dump((source, target, data), f)
|
||||
print(f"Saved: n_grams-ngram_{ngram_level}-seed_{seed}.pickle")
|
||||
|
||||
if args.vocab:
|
||||
print("Reading vocab")
|
||||
with open(args.vocab, 'rb') as f:
|
||||
vocab, source_int, target_int = pickle.load(f)
|
||||
print("Vocab read")
|
||||
else:
|
||||
vocab = create_vocab(data)
|
||||
source_int, target_int = segment_with_vocab(vocab, target, source)
|
||||
print("Saving progress")
|
||||
with open(f"vocab-seed_{seed}.pickle", 'wb+') as f:
|
||||
pickle.dump((vocab, source_int, target_int), f)
|
||||
print(f"Saved: vocab-seed_{seed}.pickle")
|
||||
|
||||
vocab_size = len(vocab)
|
||||
|
||||
net = NgramModel(vocab_size=vocab_size)
|
||||
|
||||
if args.model:
|
||||
net.load_state_dict(torch.load(args.model))
|
||||
|
||||
if torch.cuda.is_available():
|
||||
net.cuda()
|
||||
|
||||
train(net, source_int, target_int, seed)
|
||||
|
||||
main()
|
Loading…
Reference in New Issue
Block a user