92 lines
2.4 KiB
Python
Executable File
92 lines
2.4 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
|
|
import gzip
|
|
import logging
|
|
import lzma
|
|
from typing import Optional
|
|
|
|
import pandas as pd
|
|
from tqdm import tqdm
|
|
|
|
from simpletransformers.classification import ClassificationModel
|
|
|
|
# Logger for this module, named after the importing module (__name__).
logger = logging.getLogger(name=__name__)
|
|
|
|
|
|
def open_file(path, *args, **kwargs):
    """
    Open a file, picking a decompressing opener from the file name.

    Names ending in ``gz`` are opened with ``gzip.open``, names ending in
    ``xz`` with ``lzma.open``, and everything else with the built-in
    ``open``.

    Args:
        path: file path (``str``, ``bytes`` or any ``os.PathLike``)
        *args: positional arguments forwarded to the opener (e.g. mode)
        **kwargs: keyword arguments forwarded to the opener
            (e.g. ``encoding='utf-8'``)

    Returns:
        the file object returned by the selected opener
    """
    import os  # local import keeps this helper self-contained

    # fsdecode turns str / bytes / PathLike into str, so the suffix test
    # also works for pathlib.Path and bytes paths (str.endswith alone
    # fails on those). The original `path` is still what gets opened.
    name = os.fsdecode(path)
    if name.endswith('gz'):
        opener = gzip.open
    elif name.endswith('xz'):
        opener = lzma.open
    else:
        opener = open
    return opener(path, *args, **kwargs)
|
|
|
|
|
|
def load_train(path: str, max_lines: Optional[int] = None) -> pd.DataFrame:
    """
    Load train/validate data.

    Each line is expected to be ``<label>\\t<text>``, where ``<label>`` is
    an integer written as a string. Only the first tab separates the label
    from the text, so the text itself may contain tabs. Lines that have no
    tab after stripping surrounding whitespace are logged and skipped.

    Args:
        path: file path (opened via ``open_file``, so names ending in
            ``gz``/``xz`` are decompressed)
        max_lines: optional number of input lines to read; ``None`` reads
            the whole file

    Returns:
        loaded data: a DataFrame whose column ``0`` holds the text and
        column ``1`` the integer label (both columns exist even when no
        rows were loaded)

    Raises:
        ValueError: if a label cannot be parsed as an integer
    """
    data = []

    logger.info(f'Loading {path}')
    with open_file(path, 'rt') as f:
        for i, line in enumerate(tqdm(f)):
            # Test the cap before parsing, so exactly `max_lines` input
            # lines are processed (checking after appending read one extra).
            if max_lines is not None and i >= max_lines:
                break

            line = line.strip()
            if '\t' not in line:
                logger.error(f'Found empty or malformed line (no tab) at'
                             f' position {i + 1} - SKIP THIS LINE')
                continue

            # maxsplit=1: split on the first tab only, so a text containing
            # tabs still unpacks into exactly two values.
            label_name, text = line.split('\t', maxsplit=1)
            text = text.strip()

            # LABEL should be string number
            try:
                label_id = int(label_name)
            except ValueError as err:
                raise ValueError(f'Invalid label {label_name!r} at position'
                                 f' {i + 1} in {path}') from err
            data.append((text, label_id))

    # range(2) reproduces the default 0/1 column labels and guarantees that
    # both columns exist even for an empty result.
    return pd.DataFrame(data, columns=range(2))
|
|
|
|
|
|
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    # NOTE(review): this handle is not used anywhere below. If the intent
    # was to change the `transformers` logger's level, a setLevel(...) call
    # appears to be missing.
    transformers_logger = logging.getLogger('transformers')

    # --- Run configuration ---------------------------------------------
    seq = 512
    # NOTE(review): 'xmlr' looks like a transposition of 'xlmr'
    # (XLM-RoBERTa). It only feeds the directory names below, so renaming
    # it would change where cache/output directories are written.
    model_name = 'xmlr_large'
    args = dict(
        cache_dir=f'cache_dir-{model_name}/',
        output_dir=f'outputs-{model_name}-{seq}/',
        best_model_dir=f'outputs-{model_name}-{seq}/best_model',
        max_seq_length=seq,
        train_batch_size=10,
        num_train_epochs=1,
        evaluate_during_training=True,
        save_steps=5000,
        evaluate_during_training_steps=5000,
        use_cached_eval_features=True,
        reprocess_input_data=False,
    )

    # --- Data ----------------------------------------------------------
    # None -> no cap on the number of training lines read.
    max_lines = None
    train_df = load_train('data/train/train.tsv', max_lines=max_lines)
    eval_df = load_train('data/dev-0/data.tsv')

    # --- Model ---------------------------------------------------------
    model_type, model_weights = 'xlmroberta', 'xlm-roberta-large'
    model = ClassificationModel(
        model_type,
        model_weights,
        args=args,
        num_labels=2,
        use_cuda=True,
        cuda_device=0,
    )

    logger.info(f'START TRAINING | ARGS: {model.args}')
    model.train_model(train_df, eval_df=eval_df)
|