Fix in train

SzamanFL 2021-01-08 20:17:07 +01:00
parent bdc1e902e8
commit ee23fd9d0f
2 changed files with 20 additions and 12 deletions

View File

@@ -33,8 +33,8 @@ class NgramModel(torch.nn.Module):
         weight = next(self.parameters()).data
         if torch.cuda.is_available():
-            hidden = (weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda(), weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda())
+            hidden = weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda()
         else:
-            hidden = (weight.new(self.n_layers, batch_size, self.n_hidden).zero_(), weight.new(self.n_layers, batch_size, self.n_hidden).zero_())
+            hidden = weight.new(self.n_layers, batch_size, self.n_hidden).zero_()
         return hidden
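Note: the removed tuple is the (h, c) pair an LSTM expects, while the new single tensor matches a GRU or vanilla RNN hidden state. A minimal sketch of the new init_hidden shape, assuming a GRU core and the n_layers/n_hidden attributes used above (the constructor is hypothetical, not the repository's):

    import torch

    class GruNgramModel(torch.nn.Module):
        # Hypothetical minimal model, only to show the hidden-state shape:
        # a GRU takes one (n_layers, batch, n_hidden) tensor, an LSTM an (h, c) tuple.
        def __init__(self, n_tokens, n_hidden=256, n_layers=2):
            super().__init__()
            self.n_hidden, self.n_layers = n_hidden, n_layers
            self.rnn = torch.nn.GRU(n_tokens, n_hidden, n_layers, batch_first=True)

        def init_hidden(self, batch_size):
            weight = next(self.parameters()).data
            hidden = weight.new(self.n_layers, batch_size, self.n_hidden).zero_()
            # weight.new(...) already copies dtype/device from the parameters,
            # so the extra .cuda() only matters while the model is still on CPU.
            return hidden.cuda() if torch.cuda.is_available() else hidden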

View File

@@ -43,6 +43,7 @@ def get_ngrams(data, ngram_len=2):
         percentage = round((counter/len(data))*100, 2)
         print(f"Status: {percentage}%", end='\r')
+    print("Creating one list")
     n_grams = sum(n_grams, [])
     print("Created ngrams")
     return n_grams
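The new print marks the slow step: sum(n_grams, []) concatenates pairwise and re-copies the accumulated list each time, so it is quadratic in the total number of n-grams. A linear-time equivalent, as a sketch (the standard alternative, not what the commit does):

    import itertools

    def flatten(list_of_lists):
        # Single pass over all elements instead of repeated list concatenation.
        return list(itertools.chain.from_iterable(list_of_lists))

    assert flatten([["a", "b"], ["c"], []]) == ["a", "b", "c"]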
@@ -103,11 +104,13 @@ def train(net, source_int, target_int, seed, epochs=5, batch_size=32, lr=0.001,
     counter = 0
+    print("Start training")
+    torch.autograd.set_detect_anomaly(True)
     net.train()
     for epoch in range(epochs):
-        h = net.init_hidden(batch_size)
+        hidden = net.init_hidden(batch_size)
         #import ipdb;ipdb.set_trace()
         for x,y in get_batches(source_int, target_int, batch_size):
             counter +=1
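torch.autograd.set_detect_anomaly(True) makes autograd keep forward-pass tracebacks so that an in-place-modification error during backward points at the op that caused it; it slows training noticeably and is normally removed once the bug is located. The check can also be scoped to the suspect region instead of enabled globally:

    import torch

    # Scoped anomaly detection: only the code inside the block pays the overhead.
    with torch.autograd.detect_anomaly():
        x = torch.ones(2, requires_grad=True)
        (x * 2).sum().backward()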
@@ -117,21 +120,26 @@ def train(net, source_int, target_int, seed, epochs=5, batch_size=32, lr=0.001,
                 source = source.cuda()
                 target = target.cuda()
-            h = tuple([each.data for each in h])
+            #hidden = tuple([each.data for each in hidden])
             net.zero_grad()
-            output, h = net(source, h)
+            output, hidden = net(source, hidden)
+            hidden.detach_()
             loss = criterion(output, target.view(-1))
+            #if counter == 1:
+            #    loss.backward(retain_graph=True)
+            #else:
+            #    loss.backward()
             loss.backward()
-            nn.utils.clip_grad_norm_(net.parameters(), clip)
+            torch.nn.utils.clip_grad_norm_(net.parameters(), clip)
             optimizer.step()
             if counter % step == 0:
-                print(f"Epoch: {epoch}/{epochs} ; Step : {counter}")
+                print(f"Epoch: {epoch}/{epochs} ; Step : {counter} ; loss : {loss}")
             if counter % 500 == 0:
                 torch.save(net.state_dict(), f"checkpoint.ckpt-{counter}-epoch_{epoch}-seed_{seed}")
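The commented-out tuple repackaging and the new hidden.detach_() do the same job for the two hidden-state layouts: truncated backpropagation through time, cutting the autograd graph between batches so backward() never walks the whole epoch. A common ordering detaches out of place at the top of the loop body, before the forward pass; a sketch of one step as a function (names mirror the hunk above and are assumed, not the repository's API):

    import torch

    def train_batch(net, hidden, source, target, criterion, optimizer, clip):
        # One training step with truncated BPTT: detach cuts the graph so
        # backward() stops at the previous batch instead of the whole history.
        hidden = hidden.detach()
        net.zero_grad()
        output, hidden = net(source, hidden)
        loss = criterion(output, target.view(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(net.parameters(), clip)  # cap gradient norm
        optimizer.step()
        return loss.item(), hidden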
@@ -149,16 +157,16 @@ def main():
     if args.ngrams:
         print("Reading ngrams")
         with open(args.ngrams, 'rb') as f:
-            source, target, data = pickle.load(f)
+            source, target, data, n_grams = pickle.load(f)
         print("Ngrams read")
     else:
         data = read_clear_data(args.in_file)
         n_grams = get_ngrams(data, args.ngram_level)
         source, target = segment_data(n_grams)
         print("Saving progress...")
-        with open(f"n_grams-ngram_{ngram_level}-seed_{seed}.pickle", 'wb+') as f:
-            pickle.dump((source, target, data), f)
-        print(f"Saved: n_grams-ngram_{ngram_level}-seed_{seed}.pickle")
+        with open(f"n_grams-ngram_{args.ngram_level}-seed_{seed}.pickle", 'wb+') as f:
+            pickle.dump((source, target, data, n_grams), f)
+        print(f"Saved: n_grams-ngram_{args.ngram_level}-seed_{seed}.pickle")
     if args.vocab:
         print("Reading vocab")