petite-difference-challenge.../1-train-base.py
Karol Kaczmarek 8ce9cb5dac XLM RoBERTa
2020-06-14 18:35:22 +02:00

92 lines
2.4 KiB
Python
Executable File

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import gzip
import logging
import lzma
from typing import Optional
import pandas as pd
from tqdm import tqdm
from simpletransformers.classification import ClassificationModel
logger = logging.getLogger(__name__)
def open_file(path, *args):
    """
    Open a file, transparently decompressing gzip or xz data.

    The opener is chosen from how the file name ends: a name ending in
    ``gz`` goes through ``gzip.open``, one ending in ``xz`` through
    ``lzma.open``, and anything else through the builtin ``open``.

    Args:
        path: file path as a string (its ending selects the opener)
        *args: forwarded positionally to the chosen opener, e.g. a mode
            such as ``'rt'``

    Returns:
        the file object produced by the selected opener
    """
    # Checked in order; the first matching suffix wins.
    openers = (('gz', gzip.open), ('xz', lzma.open))
    opener = next((fn for suffix, fn in openers if path.endswith(suffix)), open)
    return opener(path, *args)
def load_train(path: str, max_lines: Optional[int] = None) -> pd.DataFrame:
    """
    Load train/validate data from a ``label<TAB>text`` file.

    Each line must hold an integer label, a tab, and the text. Only the
    first tab separates label from text, so tabs inside the text are kept.
    Lines with no tab (after stripping surrounding whitespace) are logged
    and skipped. The file may be plain, gzip- or xz-compressed; it is
    opened through ``open_file``.

    Args:
        path: file path
        max_lines: optional maximum number of input lines to read (skipped
            lines count towards the limit); ``None`` reads the whole file

    Returns:
        DataFrame with the text in column 0 and the integer label in
        column 1, one row per accepted line (a frame with no columns if
        no line was accepted)

    Raises:
        ValueError: if a label is not an integer
    """
    data = []
    logger.info('Loading %s', path)
    with open_file(path, 'rt') as f:
        for i, line in enumerate(tqdm(f)):
            # Check the limit before parsing so that exactly ``max_lines``
            # lines are read; checking after the append read one line too
            # many (more if the boundary line was skipped).
            if max_lines is not None and i >= max_lines:
                break
            line = line.strip()
            if '\t' not in line:
                logger.error('Found empty line at position %d - SKIP THIS LINE',
                             i + 1)
                continue
            # maxsplit=1: split on the first tab only. With maxsplit=2 a tab
            # inside the text produced three parts and the unpacking crashed.
            label_name, text = line.split('\t', maxsplit=1)
            # LABEL should be string number
            label_id = int(label_name)
            data.append((text.strip(), label_id))
    return pd.DataFrame(data)
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    # Keep INFO output for this script, but raise the `transformers` library
    # logger to WARNING to cut its INFO-level noise. Previously the logger
    # was fetched but never configured, which made the line a no-op.
    transformers_logger = logging.getLogger('transformers')
    transformers_logger.setLevel(logging.WARNING)

    # None reads the whole training file; set an int for a quick trial run.
    max_lines = None
    train_df = load_train('data/train/train.tsv', max_lines=max_lines)
    eval_df = load_train('data/dev-0/data.tsv')

    # Passed as max_seq_length and also embedded in the output dir names.
    seq = 512
    # NOTE(review): 'xmlr' looks like a typo for 'xlmr' (XLM-RoBERTa). It is
    # kept unchanged because it names the cache/output directories, which
    # other scripts presumably read — confirm before renaming.
    model_name = 'xmlr_base'
    args = {
        'cache_dir': f'cache_dir-{model_name}/',
        'output_dir': f'outputs-{model_name}-{seq}/',
        'best_model_dir': f'outputs-{model_name}-{seq}/best_model',
        'max_seq_length': seq,
        'train_batch_size': 25,
        'num_train_epochs': 1,
        'evaluate_during_training': True,
        'save_steps': 5000,
        'evaluate_during_training_steps': 5000,
        'use_cached_eval_features': True,
        'reprocess_input_data': False,
    }
    # Binary classifier on the pretrained 'xlm-roberta-base' checkpoint,
    # hard-coded to CUDA device 0.
    model = ClassificationModel('xlmroberta', 'xlm-roberta-base', args=args,
                                num_labels=2, use_cuda=True, cuda_device=0)
    logger.info('START TRAINING | ARGS: %s', model.args)
    model.train_model(train_df, eval_df=eval_df)