31 lines
753 B
Python
31 lines
753 B
Python
|
# -*- coding: utf-8 -*-
|
||
|
from transformers import MarianTokenizer, MarianMTModel
|
||
|
import sys
|
||
|
from typing import List
|
||
|
from numba import jit
|
||
|
|
||
|
def count(stream=None):
    """Read every line of *stream* into a dict keyed by 0-based line index.

    The function used to carry numba's ``@jit``. Numba cannot compile code
    that iterates a text stream and builds a dict of Python strings: recent
    numba releases raise, and older ones fall back to slow object mode with
    warnings. There is no numeric work here to accelerate, so it is plain
    Python.

    Args:
        stream: An iterable of text lines. Defaults to ``sys.stdin``,
            looked up at call time so redirected stdin is honoured.

    Returns:
        dict mapping line index (int, starting at 0) to that line with
        trailing whitespace (including the newline) removed.
    """
    if stream is None:
        stream = sys.stdin
    return {doc_id: line.rstrip() for doc_id, line in enumerate(stream)}
|
||
|
|
||
|
def translate(data, tokenizer=None, mt_model=None, batch_size=1):
    """Translate every value of *data* and print one translation per line.

    Args:
        data: Mapping whose values are source sentences (for example the
            dict built by ``count``). Entries are translated in the mapping's
            iteration order.
        tokenizer: Marian tokenizer to use. Defaults to the module-level
            ``tok`` created in the ``__main__`` section.
        mt_model: Marian model to use. Defaults to the module-level
            ``model`` created in the ``__main__`` section.
        batch_size: Number of sentences passed to ``generate`` per call.
            The default of 1 keeps the original one-sentence-at-a-time
            behaviour. Larger values are faster but pad shorter sentences
            within a batch.

    Raises:
        ValueError: if ``batch_size`` is smaller than 1.

    Prints each translation to stdout in input order and returns None.
    """
    if tokenizer is None:
        tokenizer = tok
    if mt_model is None:
        mt_model = model
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")

    sentences = list(data.values())
    for start in range(0, len(sentences), batch_size):
        chunk = sentences[start:start + batch_size]
        # Call the tokenizer directly instead of the deprecated
        # prepare_seq2seq_batch. Explicitly request PyTorch tensors, since
        # generate() needs tensors, and keep its padding/truncation defaults.
        batch = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True)
        generated = mt_model.generate(**batch)
        for text in tokenizer.batch_decode(generated, skip_special_tokens=True):
            print(text)
|
||
|
|
||
|
if __name__ == "__main__":
    # Translate stdin line by line with an opus-mt Marian model.
    # Usage: python script.py [SRC] [TRG] < input.txt
    # The language pair defaults to Polish -> English when no arguments are given.
    src = sys.argv[1] if len(sys.argv) > 1 else 'pl'  # source language
    trg = sys.argv[2] if len(sys.argv) > 2 else 'en'  # target language
    mname = f'Helsinki-NLP/opus-mt-{src}-{trg}'

    # Kept as module-level globals: translate() falls back to `model` and `tok`.
    model = MarianMTModel.from_pretrained(mname)
    tok = MarianTokenizer.from_pretrained(mname)

    data = count()
    translate(data)
|