wmt-2020-pl-en/train.py

from lstm_model import EncoderRNN, DecoderRNN
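# model_train supplies input_lang, output_lang, device and trainIters
from model_train import *

# Hidden-state size shared by the encoder and the decoder
hidden_size = 256
# Encoder over the source-language vocabulary
encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device)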
# Decoder over the target-language vocabulary
decoder1 = DecoderRNN(hidden_size, output_lang.n_words).to(device)

# Train for 5 iterations, printing progress after every iteration
trainIters(encoder1, decoder1, 5, print_every=1)