#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Predict F/M labels for several data splits with a fine-tuned RoBERTa model.

For each split listed in ``main`` the script reads ``data/<split>/in.tsv.xz``
(one input text per line), classifies every line with the model stored in
``outputs/best_model`` and writes one label name per line to
``data/<split>/out-<run description>.tsv``.
"""
import gzip
import logging
import lzma
from typing import List

from simpletransformers.classification import ClassificationModel
from tqdm import tqdm

# Label name -> integer class id used by the classifier.
MAPPER_LABERL2ID = {
    'F': 0,
    'M': 1,
}
# Correctly spelled alias; the misspelled name above is kept so existing
# importers keep working.
MAPPER_LABEL2ID = MAPPER_LABERL2ID
# Inverse mapping (class id -> label name), used to decode predictions.
MAPPER_ID2LABEL = {v: k for k, v in MAPPER_LABERL2ID.items()}

logger = logging.getLogger(__name__)


def open_file(path, *args, **kwargs):
    """Open *path*, transparently decompressing gzip/xz files.

    The opener is picked from the end of the path: ``...gz`` uses
    :func:`gzip.open`, ``...xz`` uses :func:`lzma.open`, anything else uses the
    builtin :func:`open`.

    Args:
        path: File path (string).
        *args: Positional arguments forwarded to the opener (e.g. the mode).
        **kwargs: Keyword arguments forwarded to the opener (e.g. ``encoding``).

    Returns:
        The file object returned by the chosen opener.
    """
    if path.endswith('gz'):
        fopen = gzip.open
    elif path.endswith('xz'):
        fopen = lzma.open
    else:
        fopen = open
    return fopen(path, *args, **kwargs)


def load_test(path: str, encoding: str = 'utf-8') -> List[str]:
    """Read one example per line from *path* (optionally gz/xz-compressed).

    Leading and trailing whitespace, including the newline, is stripped from
    every line.

    Args:
        path: Input file path.
        encoding: Text encoding used to decode the file. An explicit default
            avoids depending on the platform locale.

    Returns:
        The stripped lines, in file order.
    """
    logger.info(f'Loading {path}')
    with open_file(path, 'rt', encoding=encoding) as f:
        return [line.strip() for line in tqdm(f)]


def main() -> None:
    """Predict labels for every configured split and save them to disk."""
    logging.basicConfig(level=logging.INFO)
    # Keep the transformers library's own INFO chatter out of the log.
    transformers_logger = logging.getLogger('transformers')
    transformers_logger.setLevel(logging.WARNING)

    # Change arguments.
    # seq_len and sliding configure the model; all values are also recorded in
    # the output file name to identify the run.
    seq_len = 128
    model_size = 'small'
    corpus = 'base'
    sliding = False
    valid = 'dev-0'

    args = {
        'max_seq_length': seq_len,
        'eval_batch_size': 100,
        'reprocess_input_data': True,
        'sliding_window': sliding,
    }
    model = ClassificationModel(
        'roberta',
        'outputs/best_model',
        num_labels=len(MAPPER_LABERL2ID),
        args=args,
    )

    output_name = (
        f'model={model_size},corpus={corpus}'
        f',seq_len={seq_len},sliding={sliding}'
        f',valid={valid}'
    )

    for test_name in ['dev-0', 'dev-1', 'test-A']:
        logger.info(f'START TESTING {test_name}')
        test_data = load_test(f'data/{test_name}/in.tsv.xz')
        predictions, _raw_outputs = model.predict(test_data)

        logger.info('Saving predictions')
        out_path = f'data/{test_name}/out-{output_name}.tsv'
        with open_file(out_path, 'wt', encoding='utf-8') as w:
            # One label name per input line, in input order.
            w.writelines(
                f'{MAPPER_ID2LABEL[prediction]}\n' for prediction in predictions
            )


if __name__ == '__main__':
    main()