#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import torch
from fairseq.models.roberta import RobertaModel
from tqdm import tqdm

if __name__ == '__main__':
    # Load the fine-tuned RoBERTa classifier from the training checkpoint.
    roberta = RobertaModel.from_pretrained(
        model_name_or_path='checkpoints',
        data_name_or_path='data-bin',
        sentencepiece_vocab='roberta_base_fairseq/sentencepiece.bpe.model',
        checkpoint_file='checkpoint_best.pt',
        bpe='sentencepiece',
    )
    roberta.cuda()
    roberta.eval()

    max_seq = 512
    batch_size = 5
    pad_index = roberta.task.source_dictionary.pad()

    for dir_test in ['dev-0', 'dev-1', 'test-A']:
        # Tokenize each input line, truncating to the model's maximum length.
        lines = []
        with open(f'data/{dir_test}/in.tsv', 'rt') as f:
            for line in tqdm(f, desc=f'Reading {dir_test}'):
                tokens = roberta.encode(line.rstrip('\n'))[:max_seq]
                lines.append(tokens)

        predictions = []
        for i in tqdm(range(0, len(lines), batch_size), desc='Processing'):
            batch_text = lines[i: i + batch_size]

            # Right-pad every sequence in the batch to the length of the
            # longest one. Note: the inner loop variable is `j`, not `i`,
            # to avoid shadowing the outer batch index.
            max_len = max(tokens.size(0) for tokens in batch_text)
            input_tensor = torch.LongTensor(len(batch_text), max_len).fill_(pad_index)
            for j, tokens in enumerate(batch_text):
                input_tensor[j][:tokens.size(0)] = tokens

            with torch.no_grad():
                # predict() returns log-probabilities from the 'hesaid'
                # classification head.
                raw_prediction = roberta.predict('hesaid', input_tensor)

            # Probability of the second class (index 1, the 'M' class).
            out_tensor = torch.exp(raw_prediction[:, 1])
            for line_prediction in out_tensor:
                predictions.append(line_prediction.item())

        with open(f'data/{dir_test}/out.tsv', 'wt') as fw:
            fw.write('\n'.join(f'{p:.8f}' for p in predictions) + '\n')
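
# A minimal, self-contained sketch of the batch-padding (collation) step used
# above, factored into a helper so it can be sanity-checked without fairseq or
# a GPU. It is illustrative only and is never called by the script; the default
# pad_index=1 mirrors fairseq's standard RoBERTa pad symbol, which is an
# assumption here (the script itself reads it from the source dictionary).
def _demo_collate(batch_text, pad_index=1):
    """Right-pad variable-length 1-D LongTensors into one (B, T) batch tensor."""
    max_len = max(tokens.size(0) for tokens in batch_text)
    input_tensor = torch.LongTensor(len(batch_text), max_len).fill_(pad_index)
    for j, tokens in enumerate(batch_text):
        input_tensor[j][:tokens.size(0)] = tokens
    return input_tensor

# Example (token IDs are made up):
#   _demo_collate([torch.tensor([0, 5, 2]), torch.tensor([0, 7, 8, 9, 2])])
#   -> tensor([[0, 5, 2, 1, 1],
#              [0, 7, 8, 9, 2]])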