Fix in train
This commit is contained in:
parent
bdc1e902e8
commit
ee23fd9d0f
@ -33,8 +33,8 @@ class NgramModel(torch.nn.Module):
|
|||||||
weight = next(self.parameters()).data
|
weight = next(self.parameters()).data
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
hidden = (weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda(),weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda())
|
hidden = weight.new(self.n_layers, batch_size, self.n_hidden).zero_().cuda()
|
||||||
else:
|
else:
|
||||||
hidden = (weight.new(self.n_layers, batch_size, self.n_hidden).zero_(),weight.new(self.n_layers, batch_size, self.n_hidden).zero_())
|
hidden = weight.new(self.n_layers, batch_size, self.n_hidden).zero_()
|
||||||
|
|
||||||
return hidden
|
return hidden
|
||||||
|
28
src/train.py
28
src/train.py
@ -43,6 +43,7 @@ def get_ngrams(data, ngram_len=2):
|
|||||||
percentage = round((counter/len(data))*100, 2)
|
percentage = round((counter/len(data))*100, 2)
|
||||||
print(f"Status: {percentage}%", end='\r')
|
print(f"Status: {percentage}%", end='\r')
|
||||||
|
|
||||||
|
print("Creating one list")
|
||||||
n_grams = sum(n_grams, [])
|
n_grams = sum(n_grams, [])
|
||||||
print("Created ngrams")
|
print("Created ngrams")
|
||||||
return n_grams
|
return n_grams
|
||||||
@ -103,11 +104,13 @@ def train(net, source_int, target_int, seed, epochs=5, batch_size=32, lr=0.001,
|
|||||||
|
|
||||||
counter = 0
|
counter = 0
|
||||||
|
|
||||||
|
print("Start training")
|
||||||
|
torch.autograd.set_detect_anomaly(True)
|
||||||
net.train()
|
net.train()
|
||||||
|
|
||||||
for epoch in range(epochs):
|
for epoch in range(epochs):
|
||||||
h = net.init_hidden(batch_size)
|
hidden = net.init_hidden(batch_size)
|
||||||
|
|
||||||
|
#import ipdb;ipdb.set_trace()
|
||||||
for x,y in get_batches(source_int, target_int, batch_size):
|
for x,y in get_batches(source_int, target_int, batch_size):
|
||||||
counter +=1
|
counter +=1
|
||||||
|
|
||||||
@ -117,21 +120,26 @@ def train(net, source_int, target_int, seed, epochs=5, batch_size=32, lr=0.001,
|
|||||||
source = source.cuda()
|
source = source.cuda()
|
||||||
target = target.cuda()
|
target = target.cuda()
|
||||||
|
|
||||||
h = tuple([each.data for each in h])
|
#hidden = tuple([each.data for each in hidden])
|
||||||
|
|
||||||
net.zero_grad()
|
net.zero_grad()
|
||||||
|
|
||||||
output, h = net(source, h)
|
output, hidden = net(source, hidden)
|
||||||
|
hidden.detach_()
|
||||||
|
|
||||||
loss = criterion(output, target.view(-1))
|
loss = criterion(output, target.view(-1))
|
||||||
|
|
||||||
|
#if counter == 1:
|
||||||
|
# loss.backward(retain_graph=True)
|
||||||
|
#else:
|
||||||
|
# loss.backward()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
|
||||||
nn.utils.clip_grad_norm_(net.parameters(), clip)
|
torch.nn.utils.clip_grad_norm_(net.parameters(), clip)
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
||||||
if counter % step == 0:
|
if counter % step == 0:
|
||||||
print(f"Epoch: {epoch}/{epochs} ; Step : {counter}")
|
print(f"Epoch: {epoch}/{epochs} ; Step : {counter} ; loss : {loss}")
|
||||||
|
|
||||||
if counter % 500 == 0:
|
if counter % 500 == 0:
|
||||||
torch.save(net.state_dict(), f"checkpoint.ckpt-{counter}-epoch_{epoch}-seed_{seed}")
|
torch.save(net.state_dict(), f"checkpoint.ckpt-{counter}-epoch_{epoch}-seed_{seed}")
|
||||||
@ -149,16 +157,16 @@ def main():
|
|||||||
if args.ngrams:
|
if args.ngrams:
|
||||||
print("Reading ngrams")
|
print("Reading ngrams")
|
||||||
with open(args.ngrams, 'rb') as f:
|
with open(args.ngrams, 'rb') as f:
|
||||||
source, target, data = pickle.load(f)
|
source, target, data, n_grams = pickle.load(f)
|
||||||
print("Ngrams read")
|
print("Ngrams read")
|
||||||
else:
|
else:
|
||||||
data = read_clear_data(args.in_file)
|
data = read_clear_data(args.in_file)
|
||||||
n_grams = get_ngrams(data, args.ngram_level)
|
n_grams = get_ngrams(data, args.ngram_level)
|
||||||
source, target = segment_data(n_grams)
|
source, target = segment_data(n_grams)
|
||||||
print("Saving progress...")
|
print("Saving progress...")
|
||||||
with open(f"n_grams-ngram_{ngram_level}-seed_{seed}.pickle", 'wb+') as f:
|
with open(f"n_grams-ngram_{args.ngram_level}-seed_{seed}.pickle", 'wb+') as f:
|
||||||
pickle.dump((source, target, data), f)
|
pickle.dump((source, target, data, n_grams), f)
|
||||||
print(f"Saved: n_grams-ngram_{ngram_level}-seed_{seed}.pickle")
|
print(f"Saved: n_grams-ngram_{args.ngram_level}-seed_{seed}.pickle")
|
||||||
|
|
||||||
if args.vocab:
|
if args.vocab:
|
||||||
print("Reading vocab")
|
print("Reading vocab")
|
||||||
|
Loading…
Reference in New Issue
Block a user