wmt-2020-pl-en/simple_translator.py
2020-10-27 00:41:38 +01:00

31 lines
753 B
Python

# -*- coding: utf-8 -*-
from transformers import MarianTokenizer, MarianMTModel
import sys
from typing import List
from numba import jit
def count() -> dict:
    """Read standard input and return its lines keyed by 0-based line index.

    Trailing whitespace (including the newline) is stripped from every
    line, so a blank input line is stored as an empty string.

    Returns:
        A dict mapping each line's index to its stripped text, in input
        order. Empty when standard input is empty.
    """
    # Deliberately a plain Python function: the body iterates ``sys.stdin``
    # and fills a dict with strings, which numba's ``@jit`` cannot compile
    # in nopython mode, so the decorator could only warn or fail here.
    return {doc_id: line.rstrip() for doc_id, line in enumerate(sys.stdin)}
def translate(data):
    """Translate each source line in ``data`` and print one result per line.

    Args:
        data: mapping of line index to source text. Entries are processed
            in the mapping's iteration order and exactly one output line
            is printed per entry.

    Uses the module-level ``tok`` (tokenizer) and ``model`` objects, which
    must be assigned before this function is called with non-blank text.
    Results go to standard output; the function returns ``None``.
    """
    for text in data.values():
        if not text:
            # Pass blank lines straight through so the output stays
            # line-aligned with the input. NOTE(review): some transformers
            # releases reject empty strings in
            # MarianTokenizer.prepare_seq2seq_batch -- confirm against the
            # pinned version.
            print()
            continue
        # Request PyTorch tensors explicitly: the default of
        # ``return_tensors`` is not the same across transformers releases,
        # and ``model.generate`` needs tensors rather than plain lists.
        batch = tok.prepare_seq2seq_batch(src_texts=[text], return_tensors="pt")
        gen = model.generate(**batch)
        decoded = tok.batch_decode(gen, skip_special_tokens=True)
        print(decoded[0])
if __name__ == "__main__":
    # Language pair that selects the pretrained checkpoint.
    src, trg = "pl", "en"
    mname = f"Helsinki-NLP/opus-mt-{src}-{trg}"

    # translate() reads ``model`` and ``tok`` from module scope, so they
    # are bound here as globals before it is called.
    model = MarianMTModel.from_pretrained(mname)
    tok = MarianTokenizer.from_pretrained(mname)

    # Read all of standard input, then translate and print it line by line.
    data = count()
    translate(data)