#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""Score classification probabilities with a fine-tuned RoBERTa model.

Reads each ``data/<dir>/in.tsv``, encodes and batches its lines, and writes
one probability per input line to ``data/<dir>/out.tsv<i>`` for 12 stochastic
passes (dropout left enabled — see the NOTE in ``__main__``).
"""
import numpy as np
from tqdm import tqdm
from typing import List
from fairseq.models.roberta import RobertaModel
from collections import OrderedDict
import torch


def get_batches(data_path: str, max_seq: int, batch_size: int, pad_index: int,
                model: RobertaModel = None) -> List[torch.Tensor]:
    """Encode every line of *data_path* and pack the token ids into padded batches.

    Args:
        data_path: text file with one example per line.
        max_seq: maximum number of tokens kept per line (longer lines truncated).
        batch_size: number of examples per batch tensor.
        pad_index: dictionary index used to pad short examples.
        model: encoder providing ``encode``; defaults to the module-level
            ``roberta`` global for backward compatibility with the original
            (which read the global implicitly).

    Returns:
        List of LongTensors, each of shape (batch, longest_example_in_batch).
    """
    if model is None:
        model = roberta  # legacy behaviour: rely on the global set in __main__
    encoded = []
    with open(data_path, 'rt') as f:
        for line in tqdm(f, desc=f'Reading {data_path}'):
            encoded.append(model.encode(line.rstrip('\n'))[:max_seq])

    tensor_list = []
    for start in tqdm(range(0, len(encoded), batch_size), desc='Batching'):
        batch_text = encoded[start: start + batch_size]
        # Pad every example up to the longest one in this batch.
        max_len = max(tokens.size(0) for tokens in batch_text)
        input_tensor = torch.LongTensor(len(batch_text), max_len).fill_(pad_index)
        # A distinct name here avoids shadowing the outer loop index,
        # which the original did (both loops used `i`).
        for row, tokens in enumerate(batch_text):
            input_tensor[row][:tokens.size(0)] = tokens
        tensor_list.append(input_tensor)
    return tensor_list


def predict(roberta: RobertaModel, batches: List[torch.Tensor], save_file: str):
    """Run the 'hesaid' classification head over *batches* and write one
    probability per input line to *save_file*.

    ``predict`` returns log-probabilities, so ``exp`` of column 1 is the
    probability of the second class (the 'M' class).  The original carried a
    contradictory "first class" comment here; column 1 is the second class.
    """
    with open(save_file, 'wt') as fout:
        # Inference only — no_grad avoids accumulating autograd graphs
        # (and the memory they hold) across batches.
        with torch.no_grad():
            for batch in tqdm(batches, desc='Processing'):
                raw_prediction = roberta.predict('hesaid', batch)
                out_tensor = torch.exp(raw_prediction[:, 1])
                for line_prediction in out_tensor:
                    fout.write(f'{line_prediction.item()}\n')


def load_model() -> RobertaModel:
    """Load the fine-tuned RoBERTa checkpoint with its sentencepiece BPE."""
    return RobertaModel.from_pretrained(
        model_name_or_path='checkpoints',
        data_name_or_path='data-bin',
        sentencepiece_vocab='roberta_base_fairseq/sentencepiece.bpe.model',
        checkpoint_file='checkpoint_best.pt',
        bpe='sentencepiece',
    )


if __name__ == '__main__':
    roberta = load_model()
    roberta.cuda()
    # NOTE(review): train() keeps dropout active, so the 12 passes below each
    # produce different stochastic predictions (looks like a Monte Carlo
    # dropout ensemble, averaged downstream).  If a single deterministic
    # prediction is wanted this should be eval() — confirm intent first.
    roberta.train()
    max_seq = 512
    batch_size = 5
    pad_index = roberta.task.source_dictionary.pad()
    for dir_name in ['dev-0', 'dev-1', 'test-A']:
        batches = get_batches(f'data/{dir_name}/in.tsv', max_seq,
                              batch_size, pad_index, model=roberta)
        for i in range(12):
            print(f'Processing iteration: {i}')
            j = str(i)
            predict(roberta, batches, f'data/{dir_name}/out.tsv' + j)