Add RoBERTa classifier
This commit is contained in:
parent
943a7c8c78
commit
fe340ca26e
BIN
best_model.tar.xz
Normal file
BIN
best_model.tar.xz
Normal file
Binary file not shown.
137314
dev-0/out-model=base.tsv
Normal file
137314
dev-0/out-model=base.tsv
Normal file
File diff suppressed because it is too large
Load Diff
156606
dev-1/out-model=base.tsv
Normal file
156606
dev-1/out-model=base.tsv
Normal file
File diff suppressed because it is too large
Load Diff
67
eval.py
Executable file
67
eval.py
Executable file
@ -0,0 +1,67 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import gzip
|
||||
import logging
|
||||
import lzma
|
||||
from typing import List
|
||||
|
||||
from simpletransformers.classification import ClassificationModel
|
||||
from tqdm import tqdm
|
||||
|
||||
MAPPER_LABERL2ID = {
|
||||
'F': 0,
|
||||
'M': 1,
|
||||
}
|
||||
MAPPER_ID2LABEL = dict([(v, k) for k, v in MAPPER_LABERL2ID.items()])
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def open_file(path, *args):
|
||||
if path.endswith('gz'):
|
||||
fopen = gzip.open
|
||||
elif path.endswith('xz'):
|
||||
fopen = lzma.open
|
||||
else:
|
||||
fopen = open
|
||||
return fopen(path, *args)
|
||||
|
||||
|
||||
def load_test(path: str) -> List[str]:
|
||||
data = []
|
||||
|
||||
logger.info(f'Loading {path}')
|
||||
with open_file(path, 'rt') as f:
|
||||
for line in tqdm(f):
|
||||
line = line.strip()
|
||||
data.append(line)
|
||||
|
||||
return data
|
||||
|
||||
if __name__ == '__main__':
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
transformers_logger = logging.getLogger('transformers')
|
||||
#transformers_logger.setLevel(logging.WARNING)
|
||||
|
||||
args = {
|
||||
'train_batch_size': 200,
|
||||
'num_train_epochs': 2,
|
||||
'evaluate_during_training': True,
|
||||
'save_steps': 15000,
|
||||
'evaluate_during_training_steps': 15000,
|
||||
}
|
||||
|
||||
model = ClassificationModel('roberta', 'outputs/best_model',
|
||||
num_labels=2, args=args)
|
||||
output_name = 'model=base'
|
||||
|
||||
for test_name in ['dev-0', 'dev-1', 'test-A']:
|
||||
logger.info(f'START TESTING {test_name}')
|
||||
test_data = load_test(f'data/{test_name}/in.tsv.xz')
|
||||
predictions, raw_outputs = model.predict(test_data)
|
||||
logger.info(f'Saving predictions')
|
||||
with open_file(f'data/{test_name}/out-{output_name}.tsv', 'wt') as w:
|
||||
for prediction in predictions:
|
||||
label_name = MAPPER_ID2LABEL[prediction]
|
||||
w.write(f'{label_name}\n')
|
134618
test-A/out-model=base.tsv
Normal file
134618
test-A/out-model=base.tsv
Normal file
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue
Block a user