Use pretrained Marian model
parent c4cf2343e5
commit e5d8b26718
1560 dev-0/out.tsv
File diff suppressed because it is too large
30 simple_translator.py (new file)
@@ -0,0 +1,30 @@
# -*- coding: utf-8 -*-
from transformers import MarianTokenizer, MarianMTModel
import sys
from typing import List
from numba import jit


@jit
def count():
    # Read the input corpus from stdin into a dict keyed by line number.
    data = {}
    for doc_id, line in enumerate(sys.stdin):
        data[doc_id] = line.rstrip()
    return data


def translate(data):
    # Translate each line separately and print the decoded output.
    for key in data.keys():
        batch = tok.prepare_seq2seq_batch(src_texts=[data[key]])
        gen = model.generate(**batch)
        translated = tok.batch_decode(gen, skip_special_tokens=True)
        print(translated[0])


if __name__ == "__main__":
    src = 'pl'  # source language
    trg = 'en'  # target language
    mname = f'Helsinki-NLP/opus-mt-{src}-{trg}'

    # print('Data ready!')
    model = MarianMTModel.from_pretrained(mname)
    tok = MarianTokenizer.from_pretrained(mname)
    data = count()
    translate(data)
2322 test-A/out.tsv
File diff suppressed because it is too large
10 translate.py (deleted)
@@ -1,10 +0,0 @@
# -*- coding: utf-8 -*-
import sys
from googletrans import Translator


translator = Translator()

for line in sys.stdin:
    sentence = line.rstrip()
    translation = translator.translate(sentence, src='pl', dest='en')
    print(translation.text)