76 lines
1.9 KiB
Python
76 lines
1.9 KiB
Python
|
#!/usr/bin/env python3
|
||
|
# -*- coding: utf-8 -*-
|
||
|
|
||
|
import gzip
import logging
import lzma
import os
from typing import List

from simpletransformers.classification import ClassificationModel
from tqdm import tqdm
|
||
|
|
||
|
# Mapping from the string labels written to the output files to the integer
# class ids produced by the classifier ('F' -> 0, 'M' -> 1).
MAPPER_LABERL2ID = {
    'F': 0,
    'M': 1,
}
# Correctly spelled alias; the original (misspelled) name is kept so existing
# references to it keep working.
MAPPER_LABEL2ID = MAPPER_LABERL2ID

# Inverse mapping (class id -> label string), used to turn model predictions
# back into labels when writing output files.
MAPPER_ID2LABEL = {label_id: label for label, label_id in MAPPER_LABERL2ID.items()}

logger = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
def open_file(path, *args, **kwargs):
    """Open *path*, transparently handling gzip- and xz-compressed files.

    The opener is chosen from the end of the file name: names ending in
    ``gz`` use :func:`gzip.open`, names ending in ``xz`` use
    :func:`lzma.open`, and anything else uses the built-in :func:`open`.

    Args:
        path: File name, as ``str`` or any ``os.PathLike`` object.
        *args: Positional arguments forwarded to the opener, normally just
            the mode (e.g. ``'rt'`` or ``'wt'``). Positional arguments after
            the mode mean different things for each opener, so pass anything
            beyond the mode as a keyword.
        **kwargs: Keyword arguments (e.g. ``encoding``) forwarded to the
            opener.

    Returns:
        The file object returned by the selected opener.
    """
    # os.fspath lets pathlib.Path (and other PathLike) inputs be inspected
    # with str methods; the original object is still handed to the opener.
    name = os.fspath(path)
    if name.endswith('gz'):
        opener = gzip.open
    elif name.endswith('xz'):
        opener = lzma.open
    else:
        opener = open
    return opener(path, *args, **kwargs)
|
||
|
|
||
|
|
||
|
def load_test(path: str) -> List[str]:
    """Read every line of *path* and return them with whitespace stripped.

    The file is opened in text mode through ``open_file``, so ``.gz`` and
    ``.xz`` inputs are decompressed on the fly. A ``tqdm`` progress bar is
    shown while reading.

    NOTE(review): ``str.strip`` also removes leading/trailing tabs; if the
    input is multi-column TSV, empty edge fields would be lost — confirm this
    is intended.
    """
    logger.info(f'Loading {path}')
    with open_file(path, 'rt') as handle:
        stripped_lines = [raw.strip() for raw in tqdm(handle)]
    return stripped_lines
|
||
|
|
||
|
|
||
|
if __name__ == '__main__':
    logging.basicConfig(level=logging.INFO)
    # Keep INFO for this script's own progress messages, but quiet the
    # verbose INFO output of the `transformers` library (the logger was
    # previously fetched but never configured).
    transformers_logger = logging.getLogger('transformers')
    transformers_logger.setLevel(logging.WARNING)

    # Change arguments
    seq_len = 128
    model_size = 'small'
    corpus = 'base'
    sliding = False
    valid = 'dev-0'

    args = {
        'max_seq_length': seq_len,
        'eval_batch_size': 100,
        'reprocess_input_data': True,
        'sliding_window': sliding,
    }

    model = ClassificationModel('roberta', 'outputs/best_model',
                                num_labels=2, args=args)
    # The output file name records the run configuration. model_size, corpus
    # and valid are only used here; they do not affect which model is loaded.
    output_name = (f'model={model_size},corpus={corpus}'
                   f',seq_len={seq_len},sliding={sliding}'
                   f',valid={valid}')

    for test_name in ['dev-0', 'dev-1', 'test-A']:
        logger.info(f'START TESTING {test_name}')
        test_data = load_test(f'data/{test_name}/in.tsv.xz')
        # Only the predicted class ids are needed; raw logits are discarded.
        predictions, _raw_outputs = model.predict(test_data)
        logger.info('Saving predictions')
        with open_file(f'data/{test_name}/out-{output_name}.tsv', 'wt') as w:
            # One label per input line, in the same order as test_data.
            w.writelines(f'{MAPPER_ID2LABEL[prediction]}\n'
                         for prediction in predictions)
|