pbr-private/translate.py
2022-05-22 15:47:22 +01:00

32 lines
875 B
Python

import itertools
import nltk
from nltk import tokenize
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
# Language codes, model directory and I/O locations used by this translation run.
SRC_LANG = "en_XX"
TGT_LANG = "pl_PL"
MODEL_DIR = "model"
INPUT_PATH = 'data/text.txt'
OUTPUT_PATH = 'translation_output.txt'


def load_sentences(path=INPUT_PATH):
    """Read *path* as UTF-8 text and split it into sentences with NLTK.

    Args:
        path: Text file to read.

    Returns:
        List of sentence strings produced by ``nltk.tokenize.sent_tokenize``.
    """
    with open(path, encoding='utf-8') as f:
        lines = f.readlines()
    # Lines keep their trailing newlines; joining with spaces lets sentences
    # that wrap across physical lines be tokenized as one sentence.
    return tokenize.sent_tokenize(' '.join(lines))


def translate_sentences(sentences, model, tokenizer, tgt_lang=TGT_LANG):
    """Translate each sentence independently and return the decoded strings.

    Sentences are generated one at a time (no batching/padding). The decoder
    is forced to start with the *tgt_lang* code token so the model emits text
    in that language.

    Args:
        sentences: Iterable of source-language sentence strings.
        model: Loaded ``MBartForConditionalGeneration`` model.
        tokenizer: ``MBart50TokenizerFast`` configured with the source language.
        tgt_lang: Target language code key in ``tokenizer.lang_code_to_id``.

    Returns:
        List of translated strings, in input order.
    """
    # Loop-invariant: resolve the target-language token id once.
    forced_bos = tokenizer.lang_code_to_id[tgt_lang]
    translations = []
    for sentence in sentences:
        model_inputs = tokenizer(sentence, return_tensors="pt")
        generated_tokens = model.generate(
            **model_inputs,
            forced_bos_token_id=forced_bos,
        )
        # batch_decode yields one string per generated sequence; extend keeps
        # the result flat.
        translations.extend(
            tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        )
    return translations


def write_translations(translations, path=OUTPUT_PATH):
    """Write *translations* to *path* as UTF-8, separated by single spaces."""
    # Explicit UTF-8: Polish output contains letters (e.g. ą, ę, ł) that
    # locale-default encodings such as cp1252 cannot represent.
    with open(path, 'w', encoding='utf-8') as f:
        f.write(' '.join(translations))


def main():
    """Translate INPUT_PATH from English to Polish and write OUTPUT_PATH."""
    # NLTK >= 3.9 loads the sentence tokenizer from 'punkt_tab'; older
    # releases use 'punkt'. Fetching both keeps sent_tokenize working on either.
    nltk.download('punkt')
    nltk.download('punkt_tab')
    sentences = load_sentences()
    model = MBartForConditionalGeneration.from_pretrained(MODEL_DIR)
    tokenizer = MBart50TokenizerFast.from_pretrained(MODEL_DIR, src_lang=SRC_LANG)
    write_translations(translate_sentences(sentences, model, tokenizer))


if __name__ == "__main__":
    main()