31 lines
753 B
Python
31 lines
753 B
Python
# -*- coding: utf-8 -*-
|
|
from transformers import MarianTokenizer, MarianMTModel
|
|
import sys
|
|
from typing import List
|
|
from numba import jit
|
|
|
|
def count() -> dict:
    """Read all of standard input into a dict keyed by line index.

    Each line has its trailing whitespace (including the newline) removed.
    Indices start at 0 and follow the order in which lines are read.

    This function is intentionally not decorated with Numba's ``@jit``:
    it only iterates a text stream and fills a dict with Python strings,
    which Numba's nopython mode cannot type, so the decorator brings no
    speed-up and can make the call fail outright.

    Returns:
        dict: ``{line_index: stripped_line}`` for every line on stdin;
        empty when stdin has no data.
    """
    return {doc_id: line.rstrip() for doc_id, line in enumerate(sys.stdin)}
|
|
|
|
def translate(data):
    """Translate every text in ``data`` and print one translation per text.

    Relies on the module-level ``tok`` (tokenizer) and ``model`` objects,
    which must be assigned before this function is called.

    Args:
        data: Mapping whose values are the source-language strings. Texts
            are processed in the mapping's iteration order.
    """
    for text in data.values():
        # Call the tokenizer directly instead of the deprecated
        # ``prepare_seq2seq_batch``; ``return_tensors="pt"`` is required so
        # that ``generate`` receives tensors rather than plain Python lists.
        batch = tok([text], return_tensors="pt", padding=True, truncation=True)
        gen = model.generate(**batch)
        # Named ``decoded`` so the local does not shadow this function.
        decoded = tok.batch_decode(gen, skip_special_tokens=True)
        print(decoded[0])
|
|
|
|
if __name__ == "__main__":
    # Language pair that selects the pretrained checkpoint name.
    src, trg = 'pl', 'en'
    mname = f'Helsinki-NLP/opus-mt-{src}-{trg}'

    # translate() reads ``model`` and ``tok`` as module globals, so these
    # names must stay exactly as they are.
    model = MarianMTModel.from_pretrained(mname)
    tok = MarianTokenizer.from_pretrained(mname)

    # Collect the lines from stdin, then print their translations.
    data = count()
    translate(data)
|