101 lines
2.5 KiB
Python
101 lines
2.5 KiB
Python
|
#!/usr/bin/env python3
|
||
|
# -*- coding: utf-8 -*-
|
||
|
|
||
|
import gzip
|
||
|
import logging
|
||
|
import lzma
|
||
|
from typing import List, Optional
|
||
|
|
||
|
import pandas as pd
|
||
|
from simpletransformers.classification import ClassificationModel
|
||
|
from tqdm import tqdm
|
||
|
|
||
|
# Maps the textual class label found in the first TSV column to the integer
# id used as the model target.
# NOTE: the name keeps its historical spelling ("LABERL") because load_train
# refers to it by that name.
MAPPER_LABERL2ID = {
    'F': 0,
    'M': 1,
}
# Inverse mapping: integer class id -> textual class label.
MAPPER_ID2LABEL = {
    label_id: label for label, label_id in MAPPER_LABERL2ID.items()
}

logger = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
def open_file(path, *args, **kwargs):
    """Open *path*, transparently decompressing ``.gz`` and ``.xz`` files.

    The opener is chosen from the file-name suffix: ``gzip.open`` for
    ``.gz``, ``lzma.open`` for ``.xz``, and the builtin ``open`` otherwise.

    Args:
        path: File path as a ``str`` or path-like object (e.g.
            ``pathlib.Path``).
        *args: Positional arguments forwarded to the opener; the first is the
            mode (e.g. ``'rt'``). Later positions mean different things for
            each opener, so pass anything beyond the mode by keyword.
        **kwargs: Keyword arguments forwarded to the opener, e.g.
            ``encoding='utf-8'`` for text modes.

    Returns:
        The file object returned by the selected opener.
    """
    # str() so pathlib.Path objects work too (they have no ``endswith``).
    name = str(path)
    # Match the whole extension, dot included, so names that merely end in
    # the letters "gz"/"xz" are not mistaken for compressed files.
    if name.endswith('.gz'):
        fopen = gzip.open
    elif name.endswith('.xz'):
        fopen = lzma.open
    else:
        fopen = open
    return fopen(path, *args, **kwargs)
|
||
|
|
||
|
|
||
|
def load_train(path: str, max_lines: Optional[int] = None) -> pd.DataFrame:
    """Load a labelled, tab-separated file into a two-column DataFrame.

    Each line is ``<label>\\t<text field>[\\t<text field>...]``. The label
    must be a key of ``MAPPER_LABERL2ID``; the remaining tab-separated fields
    are joined with single spaces to form the text. Lines with an unknown
    label are logged and skipped; blank lines are skipped silently.

    Args:
        path: Path to the file; ``.gz``/``.xz`` files are decompressed by
            ``open_file``.
        max_lines: If given, read at most this many lines from the file.
            Skipped (blank or invalid) lines count towards the limit.

    Returns:
        DataFrame whose column ``0`` holds the text and column ``1`` the
        integer label id.
    """
    data = []

    logger.info(f'Loading {path}')
    with open_file(path, 'rt') as f:
        for line_idx, line in enumerate(tqdm(f)):
            # Check the limit before parsing so exactly ``max_lines`` lines
            # are read (checking after the append read one extra line).
            if max_lines is not None and line_idx >= max_lines:
                break

            line = line.strip()
            if not line:
                # A blank line (e.g. trailing newline at EOF) holds no sample.
                continue

            label_name, *text_fields = line.split('\t')
            text = ' '.join(text_fields).strip()

            label_id = MAPPER_LABERL2ID.get(label_name)
            if label_id is None:
                logger.error(f'Invalid class label "{label_name}"'
                             f' in line {line_idx}')
                continue

            data.append((text, label_id))

    # Explicit columns so an empty result still has the two expected columns.
    return pd.DataFrame(data, columns=[0, 1])
|
||
|
|
||
|
|
||
|
def load_test(path: str) -> List[str]:
    """Read an unlabelled input file as a list of stripped lines.

    Blank lines are kept as empty strings, so the result has exactly one
    entry per input line.

    Args:
        path: Path to the file; ``.gz``/``.xz`` files are decompressed by
            ``open_file``.

    Returns:
        One whitespace-stripped string per line of the file, in file order.
    """
    logger.info(f'Loading {path}')
    with open_file(path, 'rt') as handle:
        return [raw_line.strip() for raw_line in tqdm(handle)]
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    # basicConfig above enables INFO on the root logger; lower the
    # `transformers` logger to WARNING so its INFO messages do not flood the
    # output. (The logger was previously fetched but never configured.)
    transformers_logger = logging.getLogger('transformers')
    transformers_logger.setLevel(logging.WARNING)

    # Run configuration -- edit these values to change the experiment.
    model_name = 'RoBERTa-small'
    seq_len = 128
    valid = 'dev-0'

    # Set to an int to train on only the first N lines of the training file.
    max_lines = None
    train_df = load_train('data/train/train.tsv.xz', max_lines=max_lines)
    eval_df = load_train(f'data/{valid}/data.tsv')

    # NOTE(review): with reprocess_input_data=False and cached eval features,
    # features cached by an earlier run may be reused even if max_lines or
    # the data changed -- presumably; confirm against simpletransformers'
    # caching behaviour before relying on it.
    args = {
        'max_seq_length': seq_len,
        'train_batch_size': 100,
        'eval_batch_size': 100,
        'num_train_epochs': 5,
        'evaluate_during_training': True,
        'save_steps': 2500,
        'evaluate_during_training_steps': 2500,
        'use_cached_eval_features': True,
        'evaluate_during_training_verbose': True,
        'reprocess_input_data': False,
    }

    # Derive the class count from the label mapping rather than hard-coding
    # it, so the model head always matches the labels load_train produces.
    model = ClassificationModel('roberta', model_name,
                                num_labels=len(MAPPER_LABERL2ID), args=args)

    logger.info('START TRAINING')
    logger.info(f'TRAINING ARGS: {model.args}')
    model.train_model(train_df, eval_df=eval_df)
|