74 lines
2.4 KiB
Python
Executable File
74 lines
2.4 KiB
Python
Executable File
#!/usr/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
from typing import List
|
|
|
|
from fairseq.models.roberta import RobertaModel
|
|
from collections import OrderedDict
|
|
import torch
|
|
|
|
|
|
def get_batches(data_path: str, max_seq: int,
                batch_size: int, pad_index: int,
                model=None) -> List[torch.Tensor]:
    """Read a text file, encode each line, and pack the lines into padded batches.

    Args:
        data_path: path to a text file with one example per line.
        max_seq: maximum number of tokens kept per line (longer lines truncated).
        batch_size: number of lines packed into each batch tensor.
        pad_index: token id used to pad shorter lines to the batch max length.
        model: encoder exposing ``encode(str) -> LongTensor``.  Defaults to the
            module-level ``roberta`` so existing 4-argument callers keep working
            (the original silently depended on that global).

    Returns:
        A list of LongTensors, each of shape (lines_in_batch, batch_max_len).
    """
    # Fall back to the module-level model for backward compatibility.
    encoder = roberta if model is None else model

    lines = []
    with open(data_path, 'rt') as f:
        for line in tqdm(f, desc=f'Reading {data_path}'):
            lines.append(encoder.encode(line.rstrip('\n'))[:max_seq])

    tensor_list = []
    for start in tqdm(range(0, len(lines), batch_size), desc='Batching'):
        batch_text = lines[start: start + batch_size]
        # Pad every line only up to the longest line of *this* batch.
        max_len = max(tokens.size(0) for tokens in batch_text)
        # Pre-fill with the padding index, then overwrite the real tokens.
        input_tensor = torch.LongTensor(len(batch_text), max_len).fill_(pad_index)
        # Renamed from `i`: the original shadowed the outer loop index.
        for row, tokens in enumerate(batch_text):
            input_tensor[row][:tokens.size(0)] = tokens
        tensor_list.append(input_tensor)

    return tensor_list
def predict(roberta: RobertaModel, batches: List[torch.Tensor], save_file: str):
    """Score every batch with the 'hesaid' classification head and write one
    probability per input line to *save_file*.

    The written value is the probability of class index 1 (the 'M' class),
    obtained by exponentiating the head's log-probabilities.
    """
    with open(save_file, 'wt') as fout:
        for batch in tqdm(batches, desc='Processing'):
            log_probs = roberta.predict('hesaid', batch)
            # Column 1 holds the log-probability of the second ('M') class;
            # exp() converts it back to a plain probability.
            m_class_probs = torch.exp(log_probs[:, 1])
            for prob in m_class_probs:
                fout.write(f'{prob.item()}\n')
def load_model():
    """Build the fine-tuned RoBERTa model from the local project directories.

    Paths are fixed project-layout conventions: the checkpoint lives under
    ``checkpoints/``, the binarized data under ``data-bin/``, and the
    sentencepiece BPE model must match the one used at training time.
    """
    return RobertaModel.from_pretrained(
        model_name_or_path='checkpoints',
        data_name_or_path='data-bin',
        sentencepiece_vocab='roberta_base_fairseq/sentencepiece.bpe.model',
        checkpoint_file='checkpoint_best.pt',
        bpe='sentencepiece',
    )
if __name__ == '__main__':
    roberta = load_model()
    roberta.cuda()
    # NOTE(review): train() keeps dropout active during prediction.  Combined
    # with the 12 repeated prediction passes below, this looks like deliberate
    # Monte Carlo dropout sampling — confirm with the author; if each pass is
    # expected to be deterministic, this should be eval() instead.
    roberta.train()

    # Batching hyper-parameters; pad_index comes from the model's dictionary.
    max_seq = 512
    batch_size = 5
    pad_index = roberta.task.source_dictionary.pad()

    for dir_name in ['dev-0', 'dev-1', 'test-A']:
        # get_batches also reads the module-level `roberta` bound above.
        batches = get_batches(f'data/{dir_name}/in.tsv', max_seq, batch_size, pad_index)
        for i in range(12):
            print(f'Processing iteration: {i}')
            # One output file per stochastic pass: out.tsv0 .. out.tsv11.
            j = str(i)
            predict(roberta, batches, f'data/{dir_name}/out.tsv' + j)