# -*- coding: utf-8 -*-
"""Translate text read from standard input, line by line, with a MarianMT model.

When run as a script, the Helsinki-NLP OPUS-MT checkpoint for the configured
source/target language pair is loaded, every line of standard input is
translated, and each translation is printed on its own line.
"""
import sys
from typing import Dict, List  # noqa: F401  (List kept for compatibility)

from numba import jit  # noqa: F401  (kept for compatibility; no longer applied)
from transformers import MarianMTModel, MarianTokenizer


def count() -> Dict[int, str]:
    """Read all of standard input and return its lines keyed by line index.

    Returns:
        A dict mapping the zero-based line index to the line's text with
        trailing whitespace (including the newline) stripped.
    """
    # NOTE: this function is deliberately NOT decorated with numba's ``@jit``.
    # It iterates over ``sys.stdin`` and builds a dict of Python strings, which
    # Numba cannot compile in nopython mode; recent Numba releases raise an
    # error instead of silently falling back to object mode.
    return {doc_id: line.rstrip() for doc_id, line in enumerate(sys.stdin)}


def translate(data: Dict[int, str]) -> None:
    """Translate every value in *data* and print one translation per line.

    Uses the module-level ``tok`` (tokenizer) and ``model`` objects, which are
    created in the ``__main__`` section below; they must exist before this
    function is called.

    Args:
        data: Mapping of line index to source text. Entries are processed in
            the mapping's iteration order.
    """
    for text in data.values():
        # The tokenizer's __call__ replaces the deprecated
        # ``prepare_seq2seq_batch``; tensors must be requested explicitly so
        # that ``generate`` receives tensors rather than plain lists.
        batch = tok([text], return_tensors="pt", padding=True, truncation=True)
        generated = model.generate(**batch)
        decoded = tok.batch_decode(generated, skip_special_tokens=True)
        print(decoded[0])


if __name__ == "__main__":
    src = 'pl'  # source language
    trg = 'en'  # target language
    mname = f'Helsinki-NLP/opus-mt-{src}-{trg}'

    model = MarianMTModel.from_pretrained(mname)
    tok = MarianTokenizer.from_pretrained(mname)

    data = count()
    translate(data)