import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
import lzma

# import os
# os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:512"
torch.cuda.empty_cache()

top = 50
model_name = "gpt2"
device = torch.device('cuda')

model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# Keep the end of long left contexts; GPT-2 is limited to 1024 tokens.
tokenizer.truncation_side = 'left'
model.to(device)
model.eval()

for folder_name in ['dev-0', 'test-A']:
    linecount = 10519 if folder_name == 'dev-0' else 7414
    processed_lines = 0
    f = lzma.open(f'{folder_name}/in.tsv.xz', mode='rt', encoding='utf-8')
    with open(f'{folder_name}/out.tsv', 'w', encoding='utf-8') as file:
        for line in f:
            separated = line.split('\t')
            # Columns 6 and 7 hold the left and right context of the gap.
            prefix = separated[6].replace(r'\n', ' ')
            suffix = separated[7].replace(r'\n', ' ')
            suffix_words = suffix.split()
            first_next_word = suffix_words[0] if suffix_words else ''
            # prompt = f'{prefix} [TOKEN] {suffix}\n[TOKEN] = '

            inputs = tokenizer.encode(prefix, return_tensors="pt",
                                      truncation=True).to(device)
            with torch.no_grad():  # inference only; skip autograd bookkeeping
                output = model(inputs)
            # Distribution over the next token after the left context.
            probs = torch.softmax(output.logits[0, -1], dim=0)

            result = ''
            total = 0.0
            values, indices = probs.topk(top)
            for val, idx in zip(values, indices):
                final_val = val.item()
                token = tokenizer.decode([idx]).strip()
                # Skip punctuation-only tokens and a few junk subwords.
                if token in ",<>.?:;'\"/\\{[]}|_-+=)(&%^*#@!$":
                    continue
                if token in ['ia', 'ix', 'io', 'ik', 'ing']:
                    continue

                # Take the first word of the right context and check whether it
                # appears among the tokens predicted after a prompt made of the
                # left context plus the candidate gap word; if it does, boost
                # that candidate's score.
                new_prompt = f'{prefix} {token} '
                new_inputs = tokenizer.encode(new_prompt, return_tensors="pt",
                                              truncation=True).to(device)
                with torch.no_grad():
                    new_output = model(new_inputs)
                # Bug fix: softmax must use new_output, not the first pass's output.
                new_probs = torch.softmax(new_output.logits[0, -1], dim=0)
                new_values, new_indices = new_probs.topk(top)
                for new_val, new_idx in zip(new_values, new_indices):
                    # Strip the leading space GPT-2 tokens decode with,
                    # otherwise the comparison almost never matches.
                    if tokenizer.decode([new_idx]).strip() == first_next_word:
                        final_val += new_val.item()
                        break

                # Accumulate as a plain float so the leftover mass below
                # prints as a number rather than a tensor repr.
                total += val.item()
                result += f'{token}:{final_val} '
            # Remaining probability mass goes to the unknown-word bucket.
            result += f':{1 - total}'
            file.write(result + '\n')
            processed_lines += 1
            print(f'\r{folder_name} : {(processed_lines/linecount)*100:.2f}%', end='')
    f.close()
    print()