#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import gzip
import logging
import lzma
from typing import List, Optional

import pandas as pd
from simpletransformers.classification import ClassificationModel
from tqdm import tqdm

MAPPER_LABEL2ID = {
    'F': 0,
    'M': 1,
}
MAPPER_ID2LABEL = dict([(v, k) for k, v in MAPPER_LABEL2ID.items()])

logger = logging.getLogger(__name__)


def open_file(path, *args):
    # Transparently open plain, gzip-compressed or xz-compressed files.
    if path.endswith('gz'):
        fopen = gzip.open
    elif path.endswith('xz'):
        fopen = lzma.open
    else:
        fopen = open
    return fopen(path, *args)


def load_train(path: str, max_lines: Optional[int] = None) -> pd.DataFrame:
    # Read "<label>\t<text>" lines into a two-column DataFrame
    # (text first, numeric label second), the layout simpletransformers accepts.
    data = []
    logger.info(f'Loading {path}')
    with open_file(path, 'rt') as f:
        for i, line in enumerate(tqdm(f)):
            line = line.strip()
            label_name, *text = line.split('\t')
            text = ' '.join(text).strip()
            if label_name not in MAPPER_LABEL2ID:
                logger.error(f'Invalid class label "{label_name}"'
                             f' in line {i}')
                continue
            label_id = MAPPER_LABEL2ID[label_name]
            data.append((text, label_id))
            if max_lines is not None and i + 1 >= max_lines:
                break
    return pd.DataFrame(data)


def load_test(path: str) -> List[str]:
    # Read unlabeled input, one document per line.
    data = []
    logger.info(f'Loading {path}')
    with open_file(path, 'rt') as f:
        for line in tqdm(f):
            line = line.strip()
            data.append(line)
    return data


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    # Keep the verbose transformers logging out of the training log.
    transformers_logger = logging.getLogger('transformers')
    transformers_logger.setLevel(logging.WARNING)

    # Experiment settings; adjust as needed.
    model_name = 'RoBERTa-small'
    seq_len = 128
    valid = 'dev-0'
    max_lines = None

    train_df = load_train('data/train/train.tsv.xz', max_lines=max_lines)
    eval_df = load_train(f'data/{valid}/data.tsv')

    args = {
        'max_seq_length': seq_len,
        'train_batch_size': 100,
        'eval_batch_size': 100,
        'num_train_epochs': 5,
        'evaluate_during_training': True,
        'save_steps': 2500,
        'evaluate_during_training_steps': 2500,
        'use_cached_eval_features': True,
        'evaluate_during_training_verbose': True,
        'reprocess_input_data': False,
    }

    model = ClassificationModel('roberta', model_name, num_labels=2, args=args)

    logger.info('START TRAINING')
    logger.info(f'TRAINING ARGS: {model.args}')
    model.train_model(train_df, eval_df=eval_df)
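
    # ------------------------------------------------------------------
    # Hedged sketch, not part of the original pipeline: load_test() above
    # is never called here, which suggests prediction happens elsewhere.
    # The commented-out lines below show roughly how the trained model
    # could score unlabeled lines via ClassificationModel.predict(); the
    # input/output paths ('in.tsv.xz', 'out.tsv') are assumptions.
    #
    # test_texts = load_test(f'data/{valid}/in.tsv.xz')
    # predictions, raw_outputs = model.predict(test_texts)
    # with open_file(f'data/{valid}/out.tsv', 'wt') as f:
    #     for label_id in predictions:
    #         f.write(MAPPER_ID2LABEL[label_id] + '\n')
    # ------------------------------------------------------------------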