Some solution

This commit is contained in:
SzamanFL 2021-01-26 17:03:03 +01:00
parent 46763e76c2
commit eaf79ed846
5 changed files with 1208 additions and 7 deletions

BIN
src/models/decoder-70000-23 Normal file

Binary file not shown.

BIN
src/models/encoder-70000-23 Normal file

Binary file not shown.

View File

@ -96,8 +96,8 @@ def main():
encoder = Encoder(input_vocab.size, hidden_size).to(device)
decoder = Decoder(hidden_size,target_vocab.size).to(device)
encoder.load_state_dict(torch.load(args.encoder))
decoder.load_state_dict(torch.load(args.decoder))
encoder.load_state_dict(torch.load(args.encoder, map_location=torch.device('cpu')))
decoder.load_state_dict(torch.load(args.decoder, map_location=torch.device('cpu')))
data = read_clear_data(args.in_f)
with open(args.out_f, 'w+') as f:

View File

@ -111,7 +111,7 @@ def train(input_tensor, target_tensor, encoder, decoder, encoder_optim, decoder_
encoder_optim.step()
return loss.item()/ target_len
def train_iterate(pairs, encoder, decoder, n_iters, input_vocab, target_vocab, lr=0.01):
def train_iterate(pairs, encoder, decoder, n_iters, input_vocab, target_vocab, lr=1, start_from=1):
encoder_optim = torch.optim.SGD(encoder.parameters(), lr=lr)
decoder_optim = torch.optim.SGD(decoder.parameters(), lr=lr)
#import ipdb; ipdb.set_trace()
@ -120,7 +120,7 @@ def train_iterate(pairs, encoder, decoder, n_iters, input_vocab, target_vocab, l
loss_total=0
print("Start training")
for i in range(1, n_iters + 1):
for i in range(start_from, n_iters + 1):
training_pair = training_pairs[i - 1]
input_tensor = training_pair[0]
target_tensor = training_pair[1]
@ -130,7 +130,7 @@ def train_iterate(pairs, encoder, decoder, n_iters, input_vocab, target_vocab, l
if i % 1000 == 0:
loss_avg = loss_total / 1000
print(f"lavg loss: {loss_avg}")
print(f"step: {i} : avg loss: {loss_avg}")
loss_total = 0
if i % 5000 == 0:
@ -145,6 +145,7 @@ def main():
parser.add_argument("--encoder")
parser.add_argument("--decoder")
parser.add_argument("--seed")
parser.add_argument("--start", default=1)
args = parser.parse_args()
global seed
@ -172,10 +173,10 @@ def main():
if args.encoder:
encoder.load_state_dict(torch.load(args.encoder))
checkpoint = True
#checkpoint = True
if args.decoder:
decoder.load_state_dict(torch.load(args.decoder))
train_iterate(pairs, encoder, decoder, 70000, input_vocab, target_vocab)
train_iterate(pairs, encoder, decoder, 70000, input_vocab, target_vocab, start_from=int(args.start))
main()

1200
test-A/out.tsv Normal file

File diff suppressed because it is too large Load Diff