diff --git a/src/Model.py b/src/Model.py
index 632ec4b..1fa5283 100644
--- a/src/Model.py
+++ b/src/Model.py
@@ -33,8 +33,8 @@ class NgramModel(torch.nn.Module):
         weight = next(self.parameters()).data
 
         if torch.cuda.is_available():
-            hidden = (weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda(),weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda())
+            hidden = weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda()
         else:
-            hidden = (weight.new(self.n_layers, batch_size, self.n_hidden).zero_(),weight.new(self.n_layers, batch_size, self.n_hidden).zero_())
+            hidden = weight.new(self.n_layers, batch_size, self.n_hidden).zero_()
 
         return hidden
diff --git a/src/train.py b/src/train.py
index 93f8116..5692dde 100644
--- a/src/train.py
+++ b/src/train.py
@@ -43,6 +43,7 @@ def get_ngrams(data, ngram_len=2):
         percentage = round((counter/len(data))*100, 2)
         print(f"Status: {percentage}%", end='\r')
 
+    print("Creating one list")
     n_grams = sum(n_grams, [])
     print("Created ngrams")
     return n_grams
@@ -103,11 +104,13 @@ def train(net, source_int, target_int, seed, epochs=5, batch_size=32, lr=0.001,
 
     counter = 0
 
+    print("Start training")
+    torch.autograd.set_detect_anomaly(True)
     net.train()
-
     for epoch in range(epochs):
-        h = net.init_hidden(batch_size)
+        hidden = net.init_hidden(batch_size)
 
+        #import ipdb;ipdb.set_trace()
         for x,y in get_batches(source_int, target_int, batch_size):
             counter +=1
 
@@ -117,21 +120,26 @@ def train(net, source_int, target_int, seed, epochs=5, batch_size=32, lr=0.001,
 
                 source = source.cuda()
                 target = target.cuda()
 
-            h = tuple([each.data for each in h])
+            #hidden = tuple([each.data for each in hidden])
 
             net.zero_grad()
 
-            output, h = net(source, h)
+            output, hidden = net(source, hidden)
+            hidden.detach_()
             loss = criterion(output, target.view(-1))
 
+            #if counter == 1:
+            #    loss.backward(retain_graph=True)
+            #else:
+            #    loss.backward()
             loss.backward()
-            nn.utils.clip_grad_norm_(net.parameters(), clip)
+            torch.nn.utils.clip_grad_norm_(net.parameters(), clip)
             optimizer.step()
 
             if counter % step == 0:
-                print(f"Epoch: {epoch}/{epochs} ; Step : {counter}")
+                print(f"Epoch: {epoch}/{epochs} ; Step : {counter} ; loss : {loss}")
 
             if counter % 500 == 0:
                 torch.save(net.state_dict(), f"checkpoint.ckpt-{counter}-epoch_{epoch}-seed_{seed}")
 
@@ -149,16 +157,16 @@ def main():
     if args.ngrams:
         print("Reading ngrams")
         with open(args.ngrams, 'rb') as f:
-            source, target, data = pickle.load(f)
+            source, target, data, n_grams = pickle.load(f)
         print("Ngrams read")
     else:
         data = read_clear_data(args.in_file)
         n_grams = get_ngrams(data, args.ngram_level)
         source, target = segment_data(n_grams)
         print("Saving progress...")
-        with open(f"n_grams-ngram_{ngram_level}-seed_{seed}.pickle", 'wb+') as f:
-            pickle.dump((source, target, data), f)
-            print(f"Saved: n_grams-ngram_{ngram_level}-seed_{seed}.pickle")
+        with open(f"n_grams-ngram_{args.ngram_level}-seed_{seed}.pickle", 'wb+') as f:
+            pickle.dump((source, target, data, n_grams), f)
+            print(f"Saved: n_grams-ngram_{args.ngram_level}-seed_{seed}.pickle")
 
     if args.vocab:
         print("Reading vocab")
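
Note on the Model.py hunk: init_hidden now returns a single zero tensor instead of the (h, c) pair. A single state tensor is what a GRU (or vanilla RNN) expects; only an LSTM needs the two-tensor tuple. A minimal sketch of what the class could look like after this change follows; the embedding, GRU, and linear layers and their sizes are illustrative assumptions (the diff only shows init_hidden), while NgramModel, n_layers, n_hidden, and init_hidden are taken from the diff.

    import torch
    import torch.nn as nn

    class NgramModel(nn.Module):
        # Sketch only: the real class in src/Model.py is not shown in the diff.
        def __init__(self, vocab_size, n_hidden=256, n_layers=2):
            super().__init__()
            self.n_hidden = n_hidden
            self.n_layers = n_layers
            self.embed = nn.Embedding(vocab_size, n_hidden)
            self.gru = nn.GRU(n_hidden, n_hidden, n_layers, batch_first=True)
            self.fc = nn.Linear(n_hidden, vocab_size)

        def forward(self, x, hidden):
            out, hidden = self.gru(self.embed(x), hidden)
            # Flatten to (batch * seq_len, vocab) so the training loop can
            # call criterion(output, target.view(-1)).
            return self.fc(out.reshape(-1, self.n_hidden)), hidden

        def init_hidden(self, batch_size):
            # A GRU carries one state tensor of shape (n_layers, batch, n_hidden);
            # an LSTM would need the (h, c) tuple the diff removes.
            weight = next(self.parameters()).data
            hidden = weight.new(self.n_layers, batch_size, self.n_hidden).zero_()
            return hidden.cuda() if torch.cuda.is_available() else hidden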
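
Note on the train.py hunk: the old loop cut the autograd graph between batches by re-wrapping the LSTM state with tuple([each.data for each in h]); with a single state tensor, the in-place hidden.detach_() does the same job. Without the detach, the next batch's loss.backward() tries to walk back through the previous batch's already-freed graph, which is the "backward through the graph a second time" RuntimeError that the commented-out retain_graph=True experiment and torch.autograd.set_detect_anomaly(True) were evidently probing; retain_graph only postpones the error and grows memory each step. A runnable sketch of the pattern, with net, batches, criterion, and optimizer as stand-ins for the objects built elsewhere in train.py:

    import torch

    def run_epoch(net, batches, criterion, optimizer, hidden, clip=5):
        # Reduced to the gradient logic of the loop in the diff; `batches`
        # stands in for get_batches(source_int, target_int, batch_size).
        for source, target in batches:
            net.zero_grad()
            output, hidden = net(source, hidden)
            # Cut the graph here so the next iteration's backward() stops at
            # this batch instead of unrolling into the previous, freed graph.
            hidden.detach_()  # in-place; works because the GRU state is one tensor
            # For an LSTM's (h, c) tuple the equivalent would be:
            #     hidden = tuple(h.detach() for h in hidden)
            loss = criterion(output, target.view(-1))
            loss.backward()
            torch.nn.utils.clip_grad_norm_(net.parameters(), clip)
            optimizer.step()
        return hidden

Calling detach_() right after the forward pass, as the diff does, is equivalent to detaching at the top of the loop: either way the state's values carry across batches (truncated BPTT) while gradients do not.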