pbr-private/translate.py

33 lines
930 B
Python
Raw Permalink Normal View History

2022-05-29 19:32:42 +02:00
import io
import itertools

import nltk
from nltk import tokenize
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

# Script: read data/text.txt, split it into sentences, translate each sentence
# with the MBart model loaded from "model" (source language pl_PL, target
# language en_XX), and write the results to translation_output.txt.

# Fetch the Punkt sentence-splitting data that sent_tokenize relies on.
# NOTE(review): recent NLTK releases may expect the 'punkt_tab' resource
# instead -- confirm against the installed NLTK version.
nltk.download('punkt')

with io.open('data/text.txt', 'r', encoding='utf8') as f:
    lines = f.readlines()

# The tokenizer below is configured with src_lang="pl_PL", so split sentences
# with the Polish Punkt model. The default (English) model is trained on
# English abbreviations and can place sentence boundaries wrongly in Polish
# text, which would hand the translator broken fragments.
sentences = tokenize.sent_tokenize(' '.join(lines), language='polish')

# "model" is resolved by from_pretrained -- presumably a local directory
# holding the fine-tuned checkpoint; verify against the deployment layout.
model = MBartForConditionalGeneration.from_pretrained("model")
tokenizer = MBart50TokenizerFast.from_pretrained("model", src_lang="pl_PL")

# Loop-invariant: id of the target-language token passed to generate() as the
# forced first generated token.
target_lang_id = tokenizer.lang_code_to_id["en_XX"]

returns = []
for sentence in sentences:
    model_inputs = tokenizer(sentence, return_tensors="pt")
    generated_tokens = model.generate(
        **model_inputs,
        forced_bos_token_id=target_lang_id,
    )
    # batch_decode returns a list of strings (one per generated sequence).
    returns.append(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))

# Flatten the per-sentence lists into one flat list of translated strings.
returns = list(itertools.chain.from_iterable(returns))

with io.open('translation_output.txt', 'w', encoding='utf8') as f:
    # Each translated sentence is followed by a single space (including the
    # last one); this output format is kept unchanged.
    for line in returns:
        f.write(line + ' ')