#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Run a saved RoBERTa classifier over the evaluation splits and save labels.

When executed as a script, the model stored in ``outputs/best_model`` is
loaded once. For each of the splits ``dev-0``, ``dev-1`` and ``test-A`` the
lines of ``data/<split>/in.tsv.xz`` are classified and the predicted label
names are written, one per line, to ``data/<split>/out-model=base.tsv``.
"""

import gzip
import logging
import lzma
from typing import List

from simpletransformers.classification import ClassificationModel
from tqdm import tqdm

# Label name -> class id.  The "LABERL" spelling is kept on purpose so the
# existing module-level name stays importable.
MAPPER_LABERL2ID = {
    'F': 0,
    'M': 1,
}
# Class id -> label name (inverse of the mapping above).
MAPPER_ID2LABEL = {
    label_id: label for label, label_id in MAPPER_LABERL2ID.items()
}

logger = logging.getLogger(__name__)


def open_file(path: str, *args, **kwargs):
    """Open *path*, picking the opener from the file name suffix.

    Paths ending in ``gz`` are opened with :func:`gzip.open`, paths ending
    in ``xz`` with :func:`lzma.open`, and everything else with the builtin
    :func:`open`.

    Args:
        path: Location of the file to open.
        *args: Positional arguments forwarded to the opener (typically the
            mode, e.g. ``'rt'``).
        **kwargs: Keyword arguments forwarded to the opener.  Needed for
            options such as ``encoding``, which ``lzma.open`` only accepts
            as a keyword.

    Returns:
        The file object returned by the selected opener.
    """
    if path.endswith('gz'):
        fopen = gzip.open
    elif path.endswith('xz'):
        fopen = lzma.open
    else:
        fopen = open
    return fopen(path, *args, **kwargs)


def load_test(path: str) -> List[str]:
    """Read *path* and return its lines with surrounding whitespace removed.

    Every input line yields exactly one entry (blank lines become empty
    strings), so the result stays aligned line-for-line with the input.

    Args:
        path: File to read; may be gzip- or xz-compressed (see ``open_file``).

    Returns:
        The stripped lines, in file order.
    """
    logger.info('Loading %s', path)
    # Explicit UTF-8 so decoding does not depend on the locale's default.
    with open_file(path, 'rt', encoding='utf-8') as f:
        return [line.strip() for line in tqdm(f)]


def main() -> None:
    """Predict labels for every evaluation split and write them to disk."""
    logging.basicConfig(level=logging.INFO)

    # Passed to ClassificationModel unchanged.
    # NOTE(review): these keys look training-oriented; confirm which of them
    # (if any) matter when the model is only used for prediction.
    model_args = {
        'train_batch_size': 200,
        'num_train_epochs': 2,
        'evaluate_during_training': True,
        'save_steps': 15000,
        'evaluate_during_training_steps': 15000,
    }
    model = ClassificationModel(
        'roberta',
        'outputs/best_model',
        num_labels=2,
        args=model_args,
    )
    output_name = 'model=base'

    for test_name in ('dev-0', 'dev-1', 'test-A'):
        logger.info('START TESTING %s', test_name)
        test_data = load_test(f'data/{test_name}/in.tsv.xz')
        # Only the predicted class ids are used; the raw outputs are ignored.
        predictions, _raw_outputs = model.predict(test_data)

        logger.info('Saving predictions')
        out_path = f'data/{test_name}/out-{output_name}.tsv'
        with open_file(out_path, 'wt', encoding='utf-8') as w:
            # One label name per line, in the same order as the input lines.
            w.writelines(
                f'{MAPPER_ID2LABEL[prediction]}\n'
                for prediction in predictions
            )


if __name__ == '__main__':
    main()