Some solution
This commit is contained in:
parent
46763e76c2
commit
eaf79ed846
BIN
src/models/decoder-70000-23
Normal file
BIN
src/models/decoder-70000-23
Normal file
Binary file not shown.
BIN
src/models/encoder-70000-23
Normal file
BIN
src/models/encoder-70000-23
Normal file
Binary file not shown.
@ -96,8 +96,8 @@ def main():
|
|||||||
encoder = Encoder(input_vocab.size, hidden_size).to(device)
|
encoder = Encoder(input_vocab.size, hidden_size).to(device)
|
||||||
decoder = Decoder(hidden_size,target_vocab.size).to(device)
|
decoder = Decoder(hidden_size,target_vocab.size).to(device)
|
||||||
|
|
||||||
encoder.load_state_dict(torch.load(args.encoder))
|
encoder.load_state_dict(torch.load(args.encoder, map_location=torch.device('cpu')))
|
||||||
decoder.load_state_dict(torch.load(args.decoder))
|
decoder.load_state_dict(torch.load(args.decoder, map_location=torch.device('cpu')))
|
||||||
|
|
||||||
data = read_clear_data(args.in_f)
|
data = read_clear_data(args.in_f)
|
||||||
with open(args.out_f, 'w+') as f:
|
with open(args.out_f, 'w+') as f:
|
||||||
|
11
src/train.py
11
src/train.py
@ -111,7 +111,7 @@ def train(input_tensor, target_tensor, encoder, decoder, encoder_optim, decoder_
|
|||||||
encoder_optim.step()
|
encoder_optim.step()
|
||||||
return loss.item()/ target_len
|
return loss.item()/ target_len
|
||||||
|
|
||||||
def train_iterate(pairs, encoder, decoder, n_iters, input_vocab, target_vocab, lr=0.01):
|
def train_iterate(pairs, encoder, decoder, n_iters, input_vocab, target_vocab, lr=1, start_from=1):
|
||||||
encoder_optim = torch.optim.SGD(encoder.parameters(), lr=lr)
|
encoder_optim = torch.optim.SGD(encoder.parameters(), lr=lr)
|
||||||
decoder_optim = torch.optim.SGD(decoder.parameters(), lr=lr)
|
decoder_optim = torch.optim.SGD(decoder.parameters(), lr=lr)
|
||||||
#import ipdb; ipdb.set_trace()
|
#import ipdb; ipdb.set_trace()
|
||||||
@ -120,7 +120,7 @@ def train_iterate(pairs, encoder, decoder, n_iters, input_vocab, target_vocab, l
|
|||||||
loss_total=0
|
loss_total=0
|
||||||
|
|
||||||
print("Start training")
|
print("Start training")
|
||||||
for i in range(1, n_iters + 1):
|
for i in range(start_from, n_iters + 1):
|
||||||
training_pair = training_pairs[i - 1]
|
training_pair = training_pairs[i - 1]
|
||||||
input_tensor = training_pair[0]
|
input_tensor = training_pair[0]
|
||||||
target_tensor = training_pair[1]
|
target_tensor = training_pair[1]
|
||||||
@ -130,7 +130,7 @@ def train_iterate(pairs, encoder, decoder, n_iters, input_vocab, target_vocab, l
|
|||||||
|
|
||||||
if i % 1000 == 0:
|
if i % 1000 == 0:
|
||||||
loss_avg = loss_total / 1000
|
loss_avg = loss_total / 1000
|
||||||
print(f"lavg loss: {loss_avg}")
|
print(f"step: {i} : avg loss: {loss_avg}")
|
||||||
loss_total = 0
|
loss_total = 0
|
||||||
|
|
||||||
if i % 5000 == 0:
|
if i % 5000 == 0:
|
||||||
@ -145,6 +145,7 @@ def main():
|
|||||||
parser.add_argument("--encoder")
|
parser.add_argument("--encoder")
|
||||||
parser.add_argument("--decoder")
|
parser.add_argument("--decoder")
|
||||||
parser.add_argument("--seed")
|
parser.add_argument("--seed")
|
||||||
|
parser.add_argument("--start", default=1)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
global seed
|
global seed
|
||||||
@ -172,10 +173,10 @@ def main():
|
|||||||
if args.encoder:
|
if args.encoder:
|
||||||
encoder.load_state_dict(torch.load(args.encoder))
|
encoder.load_state_dict(torch.load(args.encoder))
|
||||||
|
|
||||||
checkpoint = True
|
#checkpoint = True
|
||||||
if args.decoder:
|
if args.decoder:
|
||||||
decoder.load_state_dict(torch.load(args.decoder))
|
decoder.load_state_dict(torch.load(args.decoder))
|
||||||
|
|
||||||
train_iterate(pairs, encoder, decoder, 70000, input_vocab, target_vocab)
|
train_iterate(pairs, encoder, decoder, 70000, input_vocab, target_vocab, start_from=int(args.start))
|
||||||
|
|
||||||
main()
|
main()
|
||||||
|
1200
test-A/out.tsv
Normal file
1200
test-A/out.tsv
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user