#!/usr/bin/env python
# coding: utf-8
"""Fine-tune BERT with masked-language modelling and fill word gaps.

The script fine-tunes ``google-bert/bert-base-uncased`` on ``train.txt``,
saves the model and tokenizer to ``./bert_output``, and then writes one
prediction line per row of ``dev-0/in.tsv`` to ``dev-0/out.tsv``.
"""

# In[1]:

import os

import regex as re
from datasets import load_dataset
# tqdm.auto selects the notebook widget when available and falls back to the
# console bar otherwise, so the script also works outside Jupyter.
from tqdm.auto import tqdm
from transformers import (
    AutoModelForMaskedLM,
    AutoTokenizer,
    DataCollatorForLanguageModeling,
    FillMaskPipeline,
    Trainer,
    TrainingArguments,
)

text_file_path = 'train.txt'
model_name = 'google-bert/bert-base-uncased'
output_dir = "./bert_output"

os.makedirs(output_dir, exist_ok=True)

dataset = load_dataset('text', data_files={'train': text_file_path}, streaming=False)

tokenizer = AutoTokenizer.from_pretrained(model_name)


def tokenize_function(examples):
    """Tokenize a batch of ``text`` rows, padded/truncated to 512 tokens."""
    return tokenizer(
        examples['text'],
        padding="max_length",
        truncation=True,
        max_length=512,
    )


tokenized_dataset = dataset.map(tokenize_function, batched=True)

# Randomly masks 15% of the tokens in every batch for the MLM objective.
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer,
    mlm=True,
    mlm_probability=0.15,
)

model = AutoModelForMaskedLM.from_pretrained(model_name)

training_args = TrainingArguments(
    output_dir=output_dir,
    overwrite_output_dir=True,
    num_train_epochs=1,
    per_device_train_batch_size=8,
    save_steps=10_000,
    save_total_limit=2,
)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_dataset["train"],
)

trainer.train()

model.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)

# NOTE(review): training is expected to leave the model in train mode; switch
# to eval so dropout cannot perturb the predictions below. Harmless if the
# model is already in eval mode.
model.eval()

fill_mask = FillMaskPipeline(model=model, tokenizer=tokenizer)

result = fill_mask("This is a [MASK] example.")
print(result)


# In[2]:

def clean(text):
    """Normalize a raw context string before it is fed to the model.

    Steps, in order: join words hyphenated across a literal ``\\n`` escape,
    turn literal ``\\n`` / ``\\t`` escapes and real newlines into spaces,
    drop commas/hyphens that sit between two word characters, remove all
    Unicode punctuation, and finally collapse whitespace runs and strip.
    """
    # These are the two-character escape sequences (backslash + letter),
    # not real control characters; real newlines are handled just below.
    text = text.replace('-\\n', '').replace('\\n', ' ').replace('\\t', ' ')
    text = re.sub(r'\n', ' ', text)
    text = re.sub(r'(?<=\w)[,-](?=\w)', '', text)
    text = re.sub(r'\p{P}', '', text)
    # Collapse whitespace *after* removing punctuation: stripping a
    # free-standing mark ("a - b") would otherwise leave a double space.
    text = re.sub(r'\s+', ' ', text)
    return text.strip()


def predictor(prefix, suffix):
    """Predict the word missing between ``prefix`` and ``suffix``.

    Only the last 30 characters of the cleaned prefix and the first 30
    characters of the cleaned suffix are used as context.

    Returns a single line of space-separated ``word:prob`` pairs followed by
    ``:remainder``, where ``remainder`` is the probability mass not assigned
    to any listed word.
    """
    prefix = clean(prefix)[-30:]
    suffix = clean(suffix)[:30]
    # clean() removes all punctuation, including brackets, so the context
    # cannot itself contain a literal "[MASK]" token.
    question = f"{prefix} [MASK] {suffix}"
    candidates = fill_mask(question)

    parts = []
    probs_sum = 0.0
    for candidate in candidates:
        word = candidate['token_str']
        prob = candidate['score']
        # ':' and whitespace are the delimiters of the output format, so a
        # token containing them (or an empty token) cannot be written out;
        # its probability is left in the remainder instead.
        if not word or ':' in word or any(ch.isspace() for ch in word):
            continue
        probs_sum += prob
        parts.append(f"{word}:{prob}")

    # Guard against a tiny negative remainder caused by float rounding.
    parts.append(f":{max(0.0, 1 - probs_sum)}")
    return ' '.join(parts)


# In[3]:

def generate_result(input_path, output_path='out.tsv'):
    """Write one prediction line to ``output_path`` per row of ``input_path``.

    ``input_path`` is read as a tab-separated file; columns 6 and 7
    (0-based) are used as the text before and after the gap. A row with
    fewer than eight columns raises ``IndexError`` rather than being
    skipped, so output lines always stay aligned with input lines.
    """
    lines = []
    with open(input_path, encoding='utf-8') as f:
        for line in f:
            columns = line.rstrip('\n').split('\t')
            lines.append((columns[6], columns[7]))

    with open(output_path, 'w', encoding='utf-8') as output_file:
        for prefix, suffix in tqdm(lines):
            output_file.write(predictor(prefix, suffix) + '\n')


# In[4]:

generate_result('dev-0/in.tsv', output_path='dev-0/out.tsv')


# In[5]:

# Schedule a machine shutdown once the run has finished. `get_ipython` only
# exists inside IPython/Jupyter; when the file is run as a plain script the
# step is skipped instead of ending the run with a NameError.
try:
    _shell = get_ipython()
except NameError:
    _shell = None

if _shell is not None:
    _shell.system('shutdown -s -t 60 -c "D"')