petite-difference-challenge2/train.py

101 lines
2.5 KiB
Python
Raw Normal View History

2020-05-02 14:15:21 +02:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import gzip
import logging
import lzma
from itertools import islice
from typing import List, Optional

import pandas as pd
from simpletransformers.classification import ClassificationModel
from tqdm import tqdm
# Class label, as read from the first tab-separated column of the training
# data, -> integer id used as the classification target.
# NOTE: "LABERL" is a misspelling of "LABEL"; the name is kept unchanged
# because other code may reference it.
MAPPER_LABERL2ID = {
    'F': 0,
    'M': 1,
}
# Inverse mapping: integer id -> class label.
MAPPER_ID2LABEL = {
    label_id: label for label, label_id in MAPPER_LABERL2ID.items()
}
logger = logging.getLogger(__name__)
def open_file(path, *args, **kwargs):
    """Open *path*, transparently decompressing gzip or xz files.

    The opener is chosen from the end of the file name: names ending in
    ``gz`` use :func:`gzip.open`, names ending in ``xz`` use
    :func:`lzma.open`, and anything else uses the built-in :func:`open`.
    Positional and keyword arguments (typically the mode, e.g. ``'rt'``)
    are forwarded to the chosen opener unchanged.

    In text mode, ``encoding`` defaults to UTF-8 unless the caller supplies
    one, so decoding does not depend on the platform's locale encoding.

    Returns:
        The file object returned by the chosen opener.
    """
    if path.endswith('gz'):
        fopen, default_mode = gzip.open, 'rb'
    elif path.endswith('xz'):
        fopen, default_mode = lzma.open, 'rb'
    else:
        fopen, default_mode = open, 'r'
    mode = args[0] if args else kwargs.get('mode', default_mode)
    # With more than one positional argument the caller may already have
    # passed the encoding positionally; adding the keyword would duplicate it.
    if len(args) <= 1 and 'b' not in mode:
        kwargs.setdefault('encoding', 'utf-8')
    return fopen(path, *args, **kwargs)
def load_train(path: str, max_lines: Optional[int] = None) -> pd.DataFrame:
    """Load a labelled tab-separated file into a two-column DataFrame.

    Each line is ``<label><TAB><text>``; any further tab-separated fields
    are joined into the text with single spaces. Lines whose label is not a
    key of ``MAPPER_LABERL2ID`` are logged as errors and skipped.

    Args:
        path: File to read, opened via ``open_file`` (so names ending in
            ``gz``/``xz`` are decompressed).
        max_lines: If given, read at most this many lines of the file
            (non-negative; skipped invalid lines count towards the limit).
            ``None`` reads the whole file.

    Returns:
        DataFrame whose column ``0`` holds the text and column ``1`` the
        integer label id. The two columns exist even when no line is valid.
    """
    data = []
    logger.info(f'Loading {path}')
    with open_file(path, 'rt') as f:
        # islice enforces the limit exactly, including on lines that end up
        # being skipped as invalid.
        lines = islice(f, max_lines)
        for line_no, line in enumerate(tqdm(lines, total=max_lines), start=1):
            label_name, *fields = line.strip().split('\t')
            label_id = MAPPER_LABERL2ID.get(label_name)
            if label_id is None:  # ids include 0, so test for None explicitly
                logger.error(f'Invalid class label "{label_name}"'
                             f' in line {line_no}')
                continue
            text = ' '.join(fields).strip()
            data.append((text, label_id))
    # Explicit column labels keep the layout stable for an empty result.
    return pd.DataFrame(data, columns=[0, 1])
def load_test(path: str) -> List[str]:
    """Read an unlabelled input file as a list of stripped lines.

    Every line of the file yields one entry (blank lines become empty
    strings), so the result stays aligned line-by-line with the input.
    """
    logger.info(f'Loading {path}')
    with open_file(path, 'rt') as handle:
        return [raw.strip() for raw in tqdm(handle)]
def main() -> None:
    """Fine-tune a RoBERTa classifier on the training split.

    The run configuration (model, sequence length, evaluation split, data
    paths, training arguments) is hard-coded below; edit it to change the
    experiment.
    """
    logging.basicConfig(level=logging.INFO)
    # Keep the transformers library at WARNING so the root INFO level set
    # above does not make it log every INFO record.
    transformers_logger = logging.getLogger('transformers')
    transformers_logger.setLevel(logging.WARNING)

    # Run configuration -- edit these values to change the experiment.
    model_name = 'RoBERTa-small'  # model name/path handed to ClassificationModel
    seq_len = 128
    valid = 'dev-0'  # directory under data/ whose data.tsv is used for evaluation
    max_lines = None  # cap on training lines read; None reads the whole file

    train_df = load_train('data/train/train.tsv.xz', max_lines=max_lines)
    eval_df = load_train(f'data/{valid}/data.tsv')
    args = {
        'max_seq_length': seq_len,
        'train_batch_size': 100,
        'eval_batch_size': 100,
        'num_train_epochs': 5,
        'evaluate_during_training': True,
        'save_steps': 2500,
        'evaluate_during_training_steps': 2500,
        'use_cached_eval_features': True,
        'evaluate_during_training_verbose': True,
        'reprocess_input_data': False,
    }
    # Derive the number of classes from the label mapping so the two agree.
    model = ClassificationModel('roberta', model_name,
                                num_labels=len(MAPPER_LABERL2ID), args=args)
    logger.info('START TRAINING')
    logger.info('TRAINING ARGS: %s', model.args)
    model.train_model(train_df, eval_df=eval_df)


if __name__ == '__main__':
    main()