Use Marian pretrained
This commit is contained in:
parent
c4cf2343e5
commit
e5d8b26718
1560
dev-0/out.tsv
1560
dev-0/out.tsv
File diff suppressed because it is too large
Load Diff
30
simple_translator.py
Normal file
30
simple_translator.py
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from transformers import MarianTokenizer, MarianMTModel
|
||||||
|
import sys
|
||||||
|
from typing import List
|
||||||
|
from numba import jit
|
||||||
|
|
||||||
|
@jit
|
||||||
|
def count():
|
||||||
|
data={}
|
||||||
|
for doc_id,line in enumerate(sys.stdin):
|
||||||
|
data[doc_id]=line.rstrip()
|
||||||
|
return data
|
||||||
|
|
||||||
|
def translate(data):
|
||||||
|
for key in data.keys():
|
||||||
|
batch = tok.prepare_seq2seq_batch(src_texts=[data[key]])
|
||||||
|
gen = model.generate(**batch)
|
||||||
|
translate = tok.batch_decode(gen, skip_special_tokens=True)
|
||||||
|
print(translate[0])
|
||||||
|
|
||||||
|
if __name__ =="__main__":
|
||||||
|
src = 'pl' # source language
|
||||||
|
trg = 'en' # target language
|
||||||
|
mname = f'Helsinki-NLP/opus-mt-{src}-{trg}'
|
||||||
|
|
||||||
|
#print('Data ready!')
|
||||||
|
model = MarianMTModel.from_pretrained(mname)
|
||||||
|
tok = MarianTokenizer.from_pretrained(mname)
|
||||||
|
data=count()
|
||||||
|
translate(data)
|
2322
test-A/out.tsv
2322
test-A/out.tsv
File diff suppressed because it is too large
Load Diff
10
translate.py
10
translate.py
@ -1,10 +0,0 @@
|
|||||||
# -*- coding: utf-8 -*-
|
|
||||||
import sys
|
|
||||||
from googletrans import Translator
|
|
||||||
|
|
||||||
translator = Translator()
|
|
||||||
|
|
||||||
for line in sys.stdin:
|
|
||||||
sentence = line.rstrip()
|
|
||||||
translation = translator.translate(sentence, src='pl', dest='en')
|
|
||||||
print(translation.text)
|
|
Loading…
Reference in New Issue
Block a user