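# Translate a Polish text file to English with mBART-50, one sentence at a time.
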
import io
import itertools
import nltk
from nltk import tokenize
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
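
# Make sure the Punkt models used by nltk's sentence tokenizer are available.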
nltk.download('punkt')
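
# Read the source text.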
with io.open('data/text.txt', 'r', encoding='utf8') as f:
    lines = f.readlines()
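
# Join the lines and segment the full text into sentences.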
sentences = tokenize.sent_tokenize(' '.join(lines))
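
# Load the mBART-50 model and tokenizer from the local "model" directory,
# treating the input as Polish.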
model = MBartForConditionalGeneration.from_pretrained("model")
tokenizer = MBart50TokenizerFast.from_pretrained("model", src_lang="pl_PL")
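
# Translate sentence by sentence, forcing English as the first generated token.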
returns = []
for sentence in sentences:
    model_inputs = tokenizer(sentence, return_tensors="pt")
    generated_tokens = model.generate(
        **model_inputs,
        forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"]
    )
    returns.append(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))
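
# batch_decode returns a list per sentence; flatten them into one list of strings.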
returns = list(itertools.chain(*returns))
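
# Write the translations to a single output file, separated by spaces.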
with io.open('translation_output.txt', 'w', encoding='utf8') as f:
    for line in returns:
        f.write(line + ' ')