From 0db921181789f4c7df22eca648f5be46c925ca4d Mon Sep 17 00:00:00 2001
From: SzamanFL
Date: Thu, 7 Jan 2021 23:29:15 +0100
Subject: [PATCH] Added model and train script
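
NgramModel is a word-level recurrent language model: token ids pass
through an embedding layer, a multi-layer torch.nn.RNN, dropout, and a
linear projection back to vocabulary size. train.py cleans a
tab-separated corpus, builds (source, target) n-gram pairs, maps words
to integer ids, and trains with Adam, gradient clipping, and periodic
checkpointing.

Typical invocations (corpus.tsv and the pickle/checkpoint names below
are placeholders; the script derives the real names from the n-gram
level and a random seed):

    # Preprocess and train from scratch
    python src/train.py --in_file corpus.tsv --ngram_level 2

    # Resume from saved pickles and a checkpoint
    python src/train.py --ngrams n_grams-ngram_2-seed_7.pickle \
        --vocab vocab-seed_7.pickle --model checkpoint.ckpt-500-epoch_0-seed_7

A minimal smoke test for the forward pass (the vocab size of 100 and
the batch of 4 bigram contexts are arbitrary):

    import torch
    from Model import NgramModel

    net = NgramModel(vocab_size=100)
    h = net.init_hidden(batch_size=4)
    x = torch.randint(0, 100, (4, 2))   # batch of 4 two-word contexts
    out, h = net(x, h)
    assert out.shape == (4 * 2, 100)    # (batch * seq_len, vocab_size)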
---
 src/Model.py |  39 ++++++++++++
 src/train.py | 175 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 214 insertions(+)
 create mode 100644 src/Model.py
 create mode 100644 src/train.py

diff --git a/src/Model.py b/src/Model.py
new file mode 100644
index 0000000..a2c6010
--- /dev/null
+++ b/src/Model.py
@@ -0,0 +1,39 @@
+import torch
+
+class NgramModel(torch.nn.Module):
+    def __init__(self, vocab_size, n_hidden=256, n_layers=3, drop_prob=0.3, lr=0.001):
+        super().__init__()
+
+        self.drop_prob = drop_prob
+        self.n_hidden = n_hidden
+        self.n_layers = n_layers
+        self.lr = lr
+        self.vocab_size = vocab_size
+
+        self.embeddings = torch.nn.Embedding(self.vocab_size, 200)
+
+        self.rnn = torch.nn.RNN(200, self.n_hidden, self.n_layers, dropout=self.drop_prob, batch_first=True)
+
+        self.dropout = torch.nn.Dropout(self.drop_prob)
+
+        self.lin = torch.nn.Linear(self.n_hidden, self.vocab_size)
+
+    def forward(self, x, hidden):
+        embedded = self.embeddings(x)
+
+        output, hidden = self.rnn(embedded, hidden)
+
+        out = self.dropout(output)
+        # Flatten to (batch * seq_len, n_hidden) for the output projection.
+        out = out.reshape(-1, self.n_hidden)
+
+        out = self.lin(out)
+        return out, hidden
+
+    def init_hidden(self, batch_size):
+        # torch.nn.RNN expects a single hidden state tensor of shape
+        # (n_layers, batch_size, n_hidden); the (h, c) tuple is only for LSTMs.
+        # new_zeros inherits dtype and device from the parameters, so the
+        # hidden state follows the model onto the GPU automatically.
+        weight = next(self.parameters()).data
+        return weight.new_zeros(self.n_layers, batch_size, self.n_hidden)
diff --git a/src/train.py b/src/train.py
new file mode 100644
index 0000000..c98f796
--- /dev/null
+++ b/src/train.py
@@ -0,0 +1,175 @@
+#!/usr/bin/env python
+
+import argparse
+import re
+import pickle
+import random
+import numpy as np
+import torch
+
+from Model import NgramModel
+
+def clear_data(string):
+    return re.sub("[^a-z' ]", "", string)
+
+def read_clear_data(in_file):
+    print("Reading data")
+    texts = []
+    with open(in_file) as f:
+        for line in f:
+            start_period, end_period, title, symbol, text = line.rstrip('\n').split('\t')
+            # Lowercase before cleaning: the regex only keeps [a-z' ].
+            texts.append(clear_data(text.lower()))
+    print("Data read")
+    return texts
+
+def create_ngrams(string, ngram_len=2):
+    n_grams = []
+    if len(string.split()) > ngram_len:
+        for i in range(ngram_len, len(string.split())):
+            n_gram = string.split()[i-ngram_len:i+1]
+            n_grams.append(" ".join(n_gram))
+    # Strings shorter than the window are dropped so that every example
+    # has the same length and np.array() can build a rectangular array.
+    return n_grams
+
+def get_ngrams(data, ngram_len=2):
+    print("Creating ngrams")
+    n_grams = []
+    for string in data:
+        n_grams.append(create_ngrams(string, ngram_len))
+
+    n_grams = sum(n_grams, [])
+    print("Created ngrams")
+    return n_grams
+
+def segment_data(n_grams):
+    print("Segmenting data")
+    source = []
+    target = []
+
+    for string in n_grams:
+        source.append(" ".join(string.split()[:-1]))
+        target.append(" ".join(string.split()[1:]))
+
+    print("Data segmented")
+    return source, target
+
+def create_vocab(data):
+    print("Creating vocab")
+    vocab = {word: i for i, word in enumerate(set(" ".join(data).split()))}
+    print("Vocab created")
+    return vocab
+
+def segment_with_vocab(vocab, source, target):
+    print("Segmenting...")
+    def get_int_seq(seq):
+        return [vocab[word] for word in seq.split()]
+
+    source_int = np.array([get_int_seq(i) for i in source])
+    target_int = np.array([get_int_seq(i) for i in target])
+
+    print("Segmented")
+    return source_int, target_int
+
+def get_batches(source_arr, target_arr, batch_size):
+    # Yield only full batches so the hidden state batch size stays
+    # constant; a trailing partial batch is dropped.
+    counter = 0
+    for n in range(batch_size, source_arr.shape[0], batch_size):
+        x = source_arr[counter:n, :]
+        y = target_arr[counter:n, :]
+        counter = n
+        yield x, y
+
+def train(net, source_int, target_int, seed, epochs=5, batch_size=32, lr=0.001, clip=1, step=30):
+    optimizer = torch.optim.Adam(net.parameters(), lr=lr)
+    criterion = torch.nn.CrossEntropyLoss()
+
+    counter = 0
+
+    net.train()
+
+    for epoch in range(epochs):
+        h = net.init_hidden(batch_size)
+
+        for x, y in get_batches(source_int, target_int, batch_size):
+            counter += 1
+
+            source, target = torch.from_numpy(x), torch.from_numpy(y)
+
+            if torch.cuda.is_available():
+                source = source.cuda()
+                target = target.cuda()
+
+            # Detach the hidden state so gradients do not flow across batches.
+            h = h.data
+
+            net.zero_grad()
+
+            output, h = net(source, h)
+
+            loss = criterion(output, target.view(-1))
+
+            loss.backward()
+
+            torch.nn.utils.clip_grad_norm_(net.parameters(), clip)
+            optimizer.step()
+
+            if counter % step == 0:
+                print(f"Epoch: {epoch}/{epochs} ; Step: {counter} ; Loss: {loss.item():.4f}")
+
+            if counter % 500 == 0:
+                torch.save(net.state_dict(), f"checkpoint.ckpt-{counter}-epoch_{epoch}-seed_{seed}")
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--in_file', help="Path to tab-separated input file")
+    parser.add_argument('--ngram_level', type=int, default=2, help="Level of ngram")
+    parser.add_argument('--ngrams', help="Path to pickle with ready ngrams")
+    parser.add_argument('--vocab', help="Path to pickle with ready vocab")
+    parser.add_argument('--model', help="Path to checkpoint to resume from")
+    args = parser.parse_args()
+    seed = random.randint(0, 20)

+
+    if args.ngrams:
+        print("Reading ngrams")
+        with open(args.ngrams, 'rb') as f:
+            source, target, data = pickle.load(f)
+        print("Ngrams read")
+    else:
+        data = read_clear_data(args.in_file)
+        n_grams = get_ngrams(data, args.ngram_level)
+        source, target = segment_data(n_grams)
+        print("Saving progress...")
+        with open(f"n_grams-ngram_{args.ngram_level}-seed_{seed}.pickle", 'wb+') as f:
+            pickle.dump((source, target, data), f)
+        print(f"Saved: n_grams-ngram_{args.ngram_level}-seed_{seed}.pickle")
+
+    if args.vocab:
+        print("Reading vocab")
+        with open(args.vocab, 'rb') as f:
+            vocab, source_int, target_int = pickle.load(f)
+        print("Vocab read")
+    else:
+        vocab = create_vocab(data)
+        source_int, target_int = segment_with_vocab(vocab, source, target)
+        print("Saving progress")
+        with open(f"vocab-seed_{seed}.pickle", 'wb+') as f:
+            pickle.dump((vocab, source_int, target_int), f)
+        print(f"Saved: vocab-seed_{seed}.pickle")
+
+    vocab_size = len(vocab)
+
+    net = NgramModel(vocab_size=vocab_size)
+
+    if args.model:
+        net.load_state_dict(torch.load(args.model))
+
+    if torch.cuda.is_available():
+        net.cuda()
+
+    train(net, source_int, target_int, seed)
+
+if __name__ == "__main__":
+    main()