petite-difference-challenge.../2-eval-large.py

83 lines
2.4 KiB
Python
Executable File

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import gzip
import logging
import lzma
import os
from typing import List
import torch
from tqdm import tqdm
from simpletransformers.classification import ClassificationModel
# Module-level logger, named after this module (``__main__`` when the file
# is executed as a script).
logger = logging.getLogger(__name__)
def open_file(path, *args, **kwargs):
    """Open *path*, transparently handling gzip- and xz-compressed files.

    The opener is selected from the file extension:

    * ``.gz`` -> :func:`gzip.open`
    * ``.xz`` -> :func:`lzma.open`
    * anything else -> the builtin :func:`open`

    The extension check includes the leading dot, so a plain file whose name
    merely ends in the letters ``gz`` or ``xz`` is not mistaken for a
    compressed one.

    Args:
        path: Location of the file, as a ``str`` or an ``os.PathLike`` object.
        *args: Positional arguments forwarded unchanged to the selected
            opener (typically just the mode, e.g. ``'rt'`` or ``'wt'``).
        **kwargs: Keyword arguments forwarded unchanged to the selected
            opener (e.g. ``encoding='utf-8'``).

    Returns:
        The file object produced by the selected opener.
    """
    # Normalise path-like objects so the suffix test works for them as well;
    # the original object is still what gets handed to the opener.
    name = os.fspath(path)
    if name.endswith('.gz'):
        fopen = gzip.open
    elif name.endswith('.xz'):
        fopen = lzma.open
    else:
        fopen = open
    return fopen(path, *args, **kwargs)
def load_test(path: str) -> List[str]:
    """Read every line of the file at *path* into a list.

    The file is opened in text mode through ``open_file`` and iterated under
    a ``tqdm`` progress bar. Leading and trailing whitespace is stripped
    from each line; lines that become empty are kept, so the returned list
    has one entry per line of the file.

    Args:
        path: Location of the input file.

    Returns:
        The stripped lines, in file order.
    """
    logger.debug(f'Loading {path}')
    with open_file(path, 'rt') as handle:
        return [raw_line.strip() for raw_line in tqdm(handle)]
if __name__ == '__main__':
    # Script entry point: for each fine-tuned checkpoint, run prediction on
    # every evaluation split and write the softmax score of class index 1,
    # one value per input line.
    logging.basicConfig(level=logging.INFO)
    # Only let WARNING and above through from the ``transformers`` logger.
    logging.getLogger('transformers').setLevel(logging.WARNING)

    for model_name in ['outputs-xmlr_large-512']:
        model_dir = os.path.join(model_name, 'best_model')
        seq_len = 512
        logger.info(f'Processing {model_name} (for sequence length: {seq_len})')

        # Derive the model size from the checkpoint directory name; 'base'
        # is checked before 'large'.
        model_type = next(
            (size for size in ('base', 'large') if size in model_name),
            None,
        )
        if model_type is None:
            raise ValueError(f'Unknown model type in name {model_name}')

        # Tag embedded in every output file name produced for this model.
        output_name = f'model=xlmr_{model_type}-seq_len={seq_len}'

        model_args = {
            'max_seq_length': seq_len,
            'eval_batch_size': 35,
            'reprocess_input_data': True,
            'sliding_window': False,
        }
        model = ClassificationModel(
            'xlmroberta',
            model_dir,
            num_labels=2,
            args=model_args,
            use_cuda=True,
            cuda_device=0,
        )

        for split_name in ['dev-0', 'dev-1', 'test-A']:
            logger.info(f'Processing {split_name}')
            texts = load_test(f'data/{split_name}/in.tsv')
            save_path = f'data/{split_name}/out-{output_name}.tsv'

            # Only the raw model outputs are needed; the hard class
            # predictions returned alongside them are ignored.
            _, raw_outputs = model.predict(texts)
            logits = torch.tensor(raw_outputs)
            probabilities = torch.nn.functional.softmax(logits, dim=1)

            logger.info(f'Saving predictions into {save_path}')
            with open_file(save_path, 'wt') as sink:
                for row in probabilities:
                    # Column 1 is the second class (referred to as the
                    # "M class" by the original author).
                    sink.write(f'{row[1].item()}\n')