8 lines
283 B
Python
8 lines
283 B
Python
from lstm_model import EncoderRNN, DecoderRNN
|
|
from model_train import *
|
|
# Build an encoder/decoder pair and run a short training session.
#
# `input_lang`, `output_lang`, `device` and `trainIters` are not defined in
# this file; they arrive through the wildcard import from `model_train`.
# NOTE(review): the training call runs at import time (no __main__ guard);
# kept that way so existing usage is unchanged.

# Hidden-state width passed to both the encoder and the decoder.
hidden_size = 256

# Encoder sized by the source-side vocabulary, moved to the shared device.
encoder1 = EncoderRNN(input_lang.n_words, hidden_size).to(device)

# Decoder sized by the target-side vocabulary.
# NOTE(review): despite the `attn_` prefix this is a plain `DecoderRNN`,
# not an attention decoder class; the name is kept for compatibility.
attn_decoder1 = DecoderRNN(hidden_size, output_lang.n_words).to(device)

# The third positional argument (5) is presumably the iteration count -- confirm
# against trainIters' signature in model_train; print_every=1 requests
# progress output on every reporting step.
trainIters(encoder1, attn_decoder1, 5, print_every=1)
|