# Translate English text to Polish, sentence by sentence, using a locally
# saved mBART-50 checkpoint (Hugging Face transformers).
import itertools

import nltk
from nltk import tokenize
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

# Fetch the Punkt sentence-tokenizer data required by sent_tokenize below.
# NOTE(review): recent NLTK releases may additionally need 'punkt_tab' — verify.
nltk.download('punkt')

# Read the whole source text. An explicit encoding avoids platform-dependent
# decoding (the default codec varies with OS/locale), which matters for the
# non-ASCII text a translation pipeline handles.
with open('data/text.txt', encoding='utf-8') as f:
    lines = f.readlines()

# Split the text into sentences so each model call gets one manageable unit.
sentences = tokenize.sent_tokenize(' '.join(lines))

# Load the locally saved mBART-50 checkpoint; the input text is English.
model = MBartForConditionalGeneration.from_pretrained("model")
tokenizer = MBart50TokenizerFast.from_pretrained("model", src_lang="en_XX")

# Translate one sentence at a time, forcing Polish as the target language.
returns = []
for sentence in sentences:
    model_inputs = tokenizer(sentence, return_tensors="pt")
    generated_tokens = model.generate(
        **model_inputs,
        forced_bos_token_id=tokenizer.lang_code_to_id["pl_PL"],
    )
    # batch_decode returns a list of strings (one per generated sequence),
    # so `returns` is a list of lists at this point.
    returns.append(tokenizer.batch_decode(generated_tokens, skip_special_tokens=True))

# Flatten the per-sentence result lists into one flat list of strings.
# chain.from_iterable consumes the outer list lazily instead of unpacking
# it into call arguments as chain(*returns) does.
returns = list(itertools.chain.from_iterable(returns))

# Write all translations space-separated in a single batched pass rather
# than one tiny write() call per sentence.
with open('translation_output.txt', 'w', encoding='utf-8') as f:
    f.writelines(line + ' ' for line in returns)