wmt-2017-cs-en/train.py

7 lines
308 B
Python
Raw Permalink Normal View History

2021-01-27 04:01:04 +01:00
from lstm_model import EncoderRNN, DecoderRNN, AttnDecoderRNN
from model_train import *
# Training entry point: build an encoder and an attention decoder (classes
# from lstm_model), move both to `device`, and hand them to trainIters.
# NOTE(review): `input_lang`, `output_lang`, `device` and `trainIters` are not
# defined in this file; presumably they come from `from model_train import *`.
# NOTE(review): everything below runs at import time; consider wrapping it in
# an `if __name__ == "__main__":` guard if this module is ever imported.

# Hidden size passed to both the encoder and the decoder.
hidden_size = 256

# Arguments forwarded to trainIters. The positional value is presumably the
# number of training iterations and print_every presumably the reporting
# interval. TODO confirm against model_train.trainIters.
N_ITERS = 10000
PRINT_EVERY = 100

# Construct the encoder before the decoder: if the models draw random initial
# weights, changing the construction order would change those weights.
encoder1 = EncoderRNN(
    input_lang.n_words,
    hidden_size,
).to(device)

attn_decoder1 = AttnDecoderRNN(
    hidden_size,
    output_lang.n_words,
).to(device)

trainIters(
    encoder1,
    attn_decoder1,
    N_ITERS,
    print_every=PRINT_EVERY,
)