Fix in train
Commit ee23fd9d0f (parent bdc1e902e8)
@@ -33,8 +33,8 @@ class NgramModel(torch.nn.Module):
         weight = next(self.parameters()).data
 
         if torch.cuda.is_available():
-            hidden = (weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda(),weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda())
+            hidden = weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda()
         else:
-            hidden = (weight.new(self.n_layers, batch_size, self.n_hidden).zero_(),weight.new(self.n_layers, batch_size, self.n_hidden).zero_())
+            hidden = weight.new(self.n_layers, batch_size, self.n_hidden).zero_()
 
         return hidden
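The removed lines built the (h_0, c_0) tuple that torch.nn.LSTM expects as its initial state; the replacement returns a single zero tensor, the form torch.nn.GRU and torch.nn.RNN take. A minimal device-agnostic sketch of the same initialization (sizes are assumptions, not taken from the commit):

    import torch

    n_layers, batch_size, n_hidden = 2, 32, 256   # assumed sizes
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # One tensor suffices for a GRU/RNN state; an LSTM needs an (h_0, c_0) pair.
    hidden = torch.zeros(n_layers, batch_size, n_hidden, device=device)

Passing device= to torch.zeros avoids the duplicated CUDA/CPU branches in the original.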
src/train.py (28 changed lines)
@@ -43,6 +43,7 @@ def get_ngrams(data, ngram_len=2):
         percentage = round((counter/len(data))*100, 2)
         print(f"Status: {percentage}%", end='\r')
 
     print("Creating one list")
     n_grams = sum(n_grams, [])
+    print("Created ngrams")
     return n_grams
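The added print only marks completion; the flatten itself, sum(n_grams, []), re-copies the accumulated list on every addition and is quadratic in the number of sublists. A linear-time alternative (a sketch, not what this commit does):

    import itertools

    n_grams = [["a", "b"], ["c"], ["d", "e"]]             # toy input
    flat = list(itertools.chain.from_iterable(n_grams))   # one pass, no re-copying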
@@ -103,11 +104,13 @@ def train(net, source_int, target_int, seed, epochs=5, batch_size=32, lr=0.001,
 
     counter = 0
 
     print("Start training")
+    torch.autograd.set_detect_anomaly(True)
     net.train()
 
     for epoch in range(epochs):
-        h = net.init_hidden(batch_size)
+        hidden = net.init_hidden(batch_size)
 
+        #import ipdb;ipdb.set_trace()
         for x,y in get_batches(source_int, target_int, batch_size):
             counter +=1
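torch.autograd.set_detect_anomaly(True) makes every backward pass check for NaN/Inf gradients and record forward-pass tracebacks. It is a debugging aid rather than a permanent setting, since it slows training noticeably; a common pattern is to gate it (flag name here is an assumption):

    import torch

    DEBUG_AUTOGRAD = False          # flip to True only while bug-hunting
    torch.autograd.set_detect_anomaly(DEBUG_AUTOGRAD)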
@@ -117,21 +120,26 @@
                 source = source.cuda()
                 target = target.cuda()
 
-            h = tuple([each.data for each in h])
+            #hidden = tuple([each.data for each in hidden])
 
            net.zero_grad()
 
-            output, h = net(source, h)
+            output, hidden = net(source, hidden)
+            hidden.detach_()
 
            loss = criterion(output, target.view(-1))
 
+            #if counter == 1:
+            #    loss.backward(retain_graph=True)
+            #else:
+            #    loss.backward()
            loss.backward()
 
-            nn.utils.clip_grad_norm_(net.parameters(), clip)
+            torch.nn.utils.clip_grad_norm_(net.parameters(), clip)
            optimizer.step()
 
            if counter % step == 0:
-                print(f"Epoch: {epoch}/{epochs} ; Step : {counter}")
+                print(f"Epoch: {epoch}/{epochs} ; Step : {counter} ; loss : {loss}")
 
            if counter % 500 == 0:
                torch.save(net.state_dict(), f"checkpoint.ckpt-{counter}-epoch_{epoch}-seed_{seed}")
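The substantive change is how the recurrent state is carried between batches: the old h = tuple([each.data for each in h]) detached an LSTM-style tuple via .data, while the new code detaches the single GRU-style tensor in place with hidden.detach_(), truncating backpropagation through time at each batch boundary and making the commented-out retain_graph workaround unnecessary. A self-contained sketch of the pattern (toy sizes; a plain nn.GRU stands in for the project's NgramModel):

    import torch
    import torch.nn as nn

    gru = nn.GRU(input_size=8, hidden_size=16, num_layers=2, batch_first=True)
    hidden = torch.zeros(2, 4, 16)        # (n_layers, batch, n_hidden)
    for step in range(3):                 # stand-in for the batch loop
        x = torch.randn(4, 5, 8)          # (batch, seq_len, input_size)
        out, hidden = gru(x, hidden)
        hidden = hidden.detach()          # truncate BPTT; detach_() is the in-place form
        out.sum().backward()              # gradients stay within this batch

The clip_grad_norm_ edit is cosmetic: nn.utils.clip_grad_norm_ and torch.nn.utils.clip_grad_norm_ are the same function when nn is torch.nn.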
@@ -149,16 +157,16 @@ def main():
     if args.ngrams:
         print("Reading ngrams")
         with open(args.ngrams, 'rb') as f:
-            source, target, data = pickle.load(f)
+            source, target, data, n_grams = pickle.load(f)
         print("Ngrams read")
     else:
         data = read_clear_data(args.in_file)
         n_grams = get_ngrams(data, args.ngram_level)
         source, target = segment_data(n_grams)
         print("Saving progress...")
-        with open(f"n_grams-ngram_{ngram_level}-seed_{seed}.pickle", 'wb+') as f:
-            pickle.dump((source, target, data), f)
-            print(f"Saved: n_grams-ngram_{ngram_level}-seed_{seed}.pickle")
+        with open(f"n_grams-ngram_{args.ngram_level}-seed_{seed}.pickle", 'wb+') as f:
+            pickle.dump((source, target, data, n_grams), f)
+            print(f"Saved: n_grams-ngram_{args.ngram_level}-seed_{seed}.pickle")
 
     if args.vocab:
         print("Reading vocab")
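This keeps the pickle round-trip symmetric: the file is now written with four fields, so the read path must unpack four, and the filename uses args.ngram_level where the bare ngram_level was presumably unbound in main(). A toy round-trip illustrating the arity constraint (dummy values):

    import pickle

    source, target, data, n_grams = [1], [2], "raw text", [("a", "b")]   # stand-ins
    with open("n_grams.pickle", "wb") as f:
        pickle.dump((source, target, data, n_grams), f)
    with open("n_grams.pickle", "rb") as f:
        source, target, data, n_grams = pickle.load(f)   # unpack arity must match the dump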