Add RoBERTa classifier

This commit is contained in:
Karol Kaczmarek 2020-04-17 08:53:46 +02:00
parent 943a7c8c78
commit fe340ca26e
5 changed files with 428605 additions and 0 deletions

BIN
best_model.tar.xz Normal file

Binary file not shown.

137314
dev-0/out-model=base.tsv Normal file

File diff suppressed because it is too large Load Diff

156606
dev-1/out-model=base.tsv Normal file

File diff suppressed because it is too large Load Diff

67
eval.py Executable file
View File

@ -0,0 +1,67 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import gzip
import logging
import lzma
import os
from typing import List

from simpletransformers.classification import ClassificationModel
from tqdm import tqdm
# Class label (as written to the output files) -> integer class id.
MAPPER_LABERL2ID = {'F': 0, 'M': 1}

# Inverse lookup: integer class id -> class label.
MAPPER_ID2LABEL = {
    class_id: label for label, class_id in MAPPER_LABERL2ID.items()
}

# Module-level logger, named after the module it lives in.
logger = logging.getLogger(__name__)
def open_file(path, *args, **kwargs):
    """Open *path*, transparently handling gzip- and xz-compressed files.

    The opener is chosen from the end of the file name: names ending in
    ``gz`` go through ``gzip.open``, names ending in ``xz`` go through
    ``lzma.open``, and everything else uses the builtin ``open``.

    Args:
        path: Location of the file, as a ``str`` or an ``os.PathLike``
            object such as ``pathlib.Path``.
        *args: Positional arguments forwarded to the opener (typically
            the mode, e.g. ``'rt'`` or ``'wt'``).
        **kwargs: Keyword arguments forwarded to the opener (e.g.
            ``encoding='utf-8'``).

    Returns:
        The file object returned by the selected opener.

    Note:
        When no mode is given, ``gzip.open`` and ``lzma.open`` default to
        binary mode while the builtin ``open`` defaults to text mode, so
        callers should pass the mode explicitly.
    """
    # Path objects have no ``endswith``; inspect the plain file-system name.
    name = os.fspath(path)
    if name.endswith('gz'):
        fopen = gzip.open
    elif name.endswith('xz'):
        fopen = lzma.open
    else:
        fopen = open
    return fopen(path, *args, **kwargs)
def load_test(path: str) -> List[str]:
    """Read an input file and return its lines as a list of strings.

    The file is opened in text mode through ``open_file``. Each line has
    leading and trailing whitespace removed; no line is dropped, so the
    result has one entry per line of the file.

    Args:
        path: Location of the input file.

    Returns:
        The stripped lines, in file order.
    """
    logger.info(f'Loading {path}')
    with open_file(path, 'rt') as handle:
        # tqdm only wraps the iteration to report progress.
        return [row.strip() for row in tqdm(handle)]
if __name__ == '__main__':
    # Script entry point: load the saved classifier, predict a label for
    # every line of each data split, and write one label per line.
    logging.basicConfig(level=logging.INFO)

    # Handle on the 'transformers' logger; its level is left untouched.
    transformers_logger = logging.getLogger('transformers')

    # Options handed to ClassificationModel.
    # NOTE(review): these look like training-time settings and are
    # presumably not needed for prediction -- confirm before removing.
    model_args = {
        'train_batch_size': 200,
        'num_train_epochs': 2,
        'evaluate_during_training': True,
        'save_steps': 15000,
        'evaluate_during_training_steps': 15000,
    }
    model = ClassificationModel(
        'roberta',
        'outputs/best_model',
        num_labels=2,
        args=model_args,
    )

    # Tag embedded in every output file name.
    output_name = 'model=base'

    for test_name in ('dev-0', 'dev-1', 'test-A'):
        logger.info(f'START TESTING {test_name}')
        samples = load_test(f'data/{test_name}/in.tsv.xz')

        # predict() yields a pair; only the first element is used here.
        predictions, _raw_outputs = model.predict(samples)

        logger.info('Saving predictions')
        out_path = f'data/{test_name}/out-{output_name}.tsv'
        with open_file(out_path, 'wt') as sink:
            # Translate each predicted class id back to its label.
            sink.writelines(
                f'{MAPPER_ID2LABEL[class_id]}\n' for class_id in predictions
            )

134618
test-A/out-model=base.tsv Normal file

File diff suppressed because it is too large Load Diff