#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import gzip
import logging
import lzma
from typing import Optional

import pandas as pd
from tqdm import tqdm

from simpletransformers.classification import ClassificationModel

logger = logging.getLogger(__name__)


def open_file(path, *args):
    """Open a plain, gzip-, or xz-compressed file based on its extension."""
    if path.endswith('gz'):
        fopen = gzip.open
    elif path.endswith('xz'):
        fopen = lzma.open
    else:
        fopen = open
    return fopen(path, *args)


def load_train(path: str, max_lines: Optional[int] = None) -> pd.DataFrame:
    """
    Load train/validation data from a TSV file with `label<TAB>text` lines.

    Args:
        path: file path
        max_lines: optional maximum number of lines to read

    Returns:
        DataFrame with 'text' and 'labels' columns
    """
    data = []
    logger.info(f'Loading {path}')
    with open_file(path, 'rt') as f:
        for i, line in enumerate(tqdm(f)):
            if max_lines is not None and i >= max_lines:
                break
            line = line.strip()
            if '\t' not in line:
                logger.error(f'Found malformed line (no tab) at position {i + 1}'
                             f' - SKIP THIS LINE')
                continue
            # Split on the first tab only, so tabs inside the text survive.
            label_name, text = line.split('\t', maxsplit=1)
            text = text.strip()
            # The label is a numeric string, e.g. '0' or '1'.
            label_id = int(label_name)
            data.append((text, label_id))
    return pd.DataFrame(data, columns=['text', 'labels'])


if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    transformers_logger = logging.getLogger('transformers')
    transformers_logger.setLevel(logging.WARNING)

    max_lines = None
    train_df = load_train('data/train/train.tsv', max_lines=max_lines)
    eval_df = load_train('data/dev-0/data.tsv')

    seq = 512
    model_name = 'xlmr_base'
    args = {
        'cache_dir': f'cache_dir-{model_name}/',
        'output_dir': f'outputs-{model_name}-{seq}/',
        'best_model_dir': f'outputs-{model_name}-{seq}/best_model',
        'max_seq_length': seq,
        'train_batch_size': 25,
        'num_train_epochs': 1,
        'evaluate_during_training': True,
        'save_steps': 5000,
        'evaluate_during_training_steps': 5000,
        'use_cached_eval_features': True,
        'reprocess_input_data': False,
    }
    model = ClassificationModel('xlmroberta', 'xlm-roberta-base', args=args,
                                num_labels=2, use_cuda=True, cuda_device=0)
    logger.info(f'START TRAINING | ARGS: {model.args}')
    model.train_model(train_df, eval_df=eval_df)
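
    # Post-training sanity check (a minimal sketch: eval_model() and predict()
    # are standard simpletransformers APIs, but the exact metrics returned
    # depend on the task; the sample sentence below is purely illustrative).
    result, model_outputs, wrong_predictions = model.eval_model(eval_df)
    logger.info(f'EVAL RESULT: {result}')
    predictions, raw_outputs = model.predict(['Example sentence to classify.'])
    logger.info(f'SAMPLE PREDICTION: {predictions}')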