Some solution
This commit is contained in:
parent
46763e76c2
commit
eaf79ed846
BIN
src/models/decoder-70000-23
Normal file
BIN
src/models/decoder-70000-23
Normal file
Binary file not shown.
BIN
src/models/encoder-70000-23
Normal file
BIN
src/models/encoder-70000-23
Normal file
Binary file not shown.
@ -96,8 +96,8 @@ def main():
|
||||
encoder = Encoder(input_vocab.size, hidden_size).to(device)
|
||||
decoder = Decoder(hidden_size,target_vocab.size).to(device)
|
||||
|
||||
encoder.load_state_dict(torch.load(args.encoder))
|
||||
decoder.load_state_dict(torch.load(args.decoder))
|
||||
encoder.load_state_dict(torch.load(args.encoder, map_location=torch.device('cpu')))
|
||||
decoder.load_state_dict(torch.load(args.decoder, map_location=torch.device('cpu')))
|
||||
|
||||
data = read_clear_data(args.in_f)
|
||||
with open(args.out_f, 'w+') as f:
|
||||
|
11
src/train.py
11
src/train.py
@ -111,7 +111,7 @@ def train(input_tensor, target_tensor, encoder, decoder, encoder_optim, decoder_
|
||||
encoder_optim.step()
|
||||
return loss.item()/ target_len
|
||||
|
||||
def train_iterate(pairs, encoder, decoder, n_iters, input_vocab, target_vocab, lr=0.01):
|
||||
def train_iterate(pairs, encoder, decoder, n_iters, input_vocab, target_vocab, lr=1, start_from=1):
|
||||
encoder_optim = torch.optim.SGD(encoder.parameters(), lr=lr)
|
||||
decoder_optim = torch.optim.SGD(decoder.parameters(), lr=lr)
|
||||
#import ipdb; ipdb.set_trace()
|
||||
@ -120,7 +120,7 @@ def train_iterate(pairs, encoder, decoder, n_iters, input_vocab, target_vocab, l
|
||||
loss_total=0
|
||||
|
||||
print("Start training")
|
||||
for i in range(1, n_iters + 1):
|
||||
for i in range(start_from, n_iters + 1):
|
||||
training_pair = training_pairs[i - 1]
|
||||
input_tensor = training_pair[0]
|
||||
target_tensor = training_pair[1]
|
||||
@ -130,7 +130,7 @@ def train_iterate(pairs, encoder, decoder, n_iters, input_vocab, target_vocab, l
|
||||
|
||||
if i % 1000 == 0:
|
||||
loss_avg = loss_total / 1000
|
||||
print(f"lavg loss: {loss_avg}")
|
||||
print(f"step: {i} : avg loss: {loss_avg}")
|
||||
loss_total = 0
|
||||
|
||||
if i % 5000 == 0:
|
||||
@ -145,6 +145,7 @@ def main():
|
||||
parser.add_argument("--encoder")
|
||||
parser.add_argument("--decoder")
|
||||
parser.add_argument("--seed")
|
||||
parser.add_argument("--start", default=1)
|
||||
args = parser.parse_args()
|
||||
|
||||
global seed
|
||||
@ -172,10 +173,10 @@ def main():
|
||||
if args.encoder:
|
||||
encoder.load_state_dict(torch.load(args.encoder))
|
||||
|
||||
checkpoint = True
|
||||
#checkpoint = True
|
||||
if args.decoder:
|
||||
decoder.load_state_dict(torch.load(args.decoder))
|
||||
|
||||
train_iterate(pairs, encoder, decoder, 70000, input_vocab, target_vocab)
|
||||
train_iterate(pairs, encoder, decoder, 70000, input_vocab, target_vocab, start_from=int(args.start))
|
||||
|
||||
main()
|
||||
|
1200
test-A/out.tsv
Normal file
1200
test-A/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user