diff --git a/src/predict.py b/src/predict.py index 3c4a1d1..c88fccf 100644 --- a/src/predict.py +++ b/src/predict.py @@ -12,7 +12,7 @@ from Encoder import Encoder device = torch.device("cuda" if torch.cuda.is_available() else "cpu") SOS=0 EOS=1 -MAX_LENGTH=25 +MAX_LENGTH=40 def clear_line(string): string = ''.join( diff --git a/src/train.py b/src/train.py index 6a40ddf..ca6abe4 100644 --- a/src/train.py +++ b/src/train.py @@ -13,7 +13,7 @@ from Vocab import Vocab from Encoder import Encoder from Decoder import Decoder -MAX_LENGTH = 25 +MAX_LENGTH = 40 SOS=0 EOS=1 teacher_forcing_ratio=0.5